[Model] Port over CLIPVisionModel for VLMs (#5591)

This commit is contained in:
Roger Wang
2024-06-20 04:52:09 -07:00
committed by GitHub
parent 111af1fa2c
commit ad137cd111
9 changed files with 269 additions and 21 deletions

View File

@@ -2,9 +2,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
from transformers import LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
@@ -15,6 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -189,12 +188,11 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)
return self._select_image_features(
image_features,
@@ -317,6 +315,9 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)