[Model] Port over CLIPVisionModel for VLMs (#5591)

This commit is contained in:
Roger Wang
2024-06-20 04:52:09 -07:00
committed by GitHub
parent 111af1fa2c
commit ad137cd111
9 changed files with 269 additions and 21 deletions

View File

@@ -17,7 +17,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -70,9 +71,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
LAYER_IDX = self.layer_idx
TYPE_FEATURE = self.type_feature
img_processor_output = self.img_processor(img_embeds,
output_hidden_states=True)
img_feature = img_processor_output.hidden_states[LAYER_IDX]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds,
vision_feature_layer=LAYER_IDX)
if TYPE_FEATURE == "patch":
patch_feature = img_feature[:, 1:]
@@ -352,6 +354,9 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)