[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
@@ -1,22 +1,70 @@
|
||||
from typing import Dict, List, Protocol, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.loader import build_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.multimodal import BatchedTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    """
    Select the weights that belong to an inner vLLM model.

    Every weight name is cut at its first ``"."``. Entries whose leading
    component equals ``prefix`` are yielded with that component (and the
    separating dot) stripped off; all other entries are skipped.

    See also:
        :ref:`init_vllm_registered_model`
    """
    for qualified_name, tensor in weights:
        # Only the first dotted component is compared against the prefix.
        head, _, remainder = qualified_name.partition(".")
        if head != prefix:
            continue
        yield remainder, tensor
|
||||
|
||||
|
||||
def init_vllm_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Build an inner model that is registered to vLLM, reusing the
    arguments that were passed to the outer vLLM model.

    The model class is looked up through ``ModelRegistry`` from
    ``hf_config.architectures``; the class and all configs are then handed
    to ``build_model``, whose result is returned.
    """
    # The second element of the resolved pair is not needed here.
    inner_model_cls, _ = ModelRegistry.resolve_model_cls(
        hf_config.architectures)

    # Keyword-only configs are forwarded unchanged.
    optional_configs = {
        "lora_config": lora_config,
        "multimodal_config": multimodal_config,
        "scheduler_config": scheduler_config,
    }
    return build_model(
        inner_model_cls,
        hf_config,
        cache_config,
        quant_config,
        **optional_configs,
    )
|
||||
|
||||
|
||||
def merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_embeddings: BatchedTensors,
|
||||
image_token_id: int) -> torch.Tensor:
|
||||
"""
|
||||
Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
|
||||
in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
|
||||
Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
|
||||
positions in ``inputs_embeds`` corresponding to placeholder image tokens in
|
||||
``input_ids``.
|
||||
|
||||
Note:
|
||||
This updates `inputs_embeds` in place.
|
||||
This updates ``inputs_embeds`` in place.
|
||||
"""
|
||||
mask = (input_ids == image_token_id)
|
||||
num_expected_tokens = mask.sum()
|
||||
|
||||
Reference in New Issue
Block a user