[Model] Port over CLIPVisionModel for VLMs (#5591)

This commit is contained in:
Roger Wang
2024-06-20 04:52:09 -07:00
committed by GitHub
parent 111af1fa2c
commit ad137cd111
9 changed files with 269 additions and 21 deletions

View File

@@ -4,9 +4,7 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
import torch
import torch.nn as nn
from PIL import Image
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaNextConfig
from transformers import LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
@@ -20,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
@@ -121,7 +120,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")
@@ -219,12 +218,11 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)
return self._select_image_features(
image_features,
@@ -430,6 +428,9 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)