[Model] Port over CLIPVisionModel for VLMs (#5591)
@@ -4,9 +4,7 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
 import torch
 import torch.nn as nn
 from PIL import Image
-# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
-# transformers' impl.
-from transformers import CLIPVisionModel, LlavaNextConfig
+from transformers import LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
@@ -20,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
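The added import above is the crux of the port: `CLIPVisionModel` now resolves to vLLM's in-tree module rather than transformers'. As a rough illustration of why such a port is worthwhile, here is a hypothetical sketch (not vLLM's actual clip.py; `SketchCLIPMLP` and its layout are invented for illustration) of a CLIP feed-forward block rebuilt on vLLM's tensor-parallel linear layers, which is what lets the vision tower share vLLM's quantization and parallelism machinery:

# Hypothetical sketch only -- the real port lives in
# vllm/model_executor/models/clip.py and differs in detail.
import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

class SketchCLIPMLP(nn.Module):
    """CLIP's feed-forward block rebuilt on vLLM's parallel layers.

    Running this requires vLLM's distributed state to be initialized;
    it is shown only to illustrate the shape of an in-tree port.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = ColumnParallelLinear(hidden_size, intermediate_size)
        # CLIP configs specify hidden_act="quick_gelu"; plain GELU stands
        # in here to keep the sketch self-contained.
        self.act = nn.GELU()
        self.fc2 = RowParallelLinear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc1(x)  # vLLM linear layers return (output, bias)
        x = self.act(x)
        x, _ = self.fc2(x)
        return x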
@@ -121,7 +120,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
 
         if self.vision_language_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config.vision_config)
+            self.vision_tower = CLIPVisionModel(config=config.vision_config)
         else:
             raise TypeError("Image features are not supported by LLaVA-NeXT")
 
@@ -219,12 +218,11 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
 
     def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
-        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
-                                     output_hidden_states=True)
-
-        image_features = image_outputs.hidden_states[
-            self.config.vision_feature_layer]
+
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_tower(pixel_values.to(vision_tower.device),
+                                      self.config.vision_feature_layer)
 
         return self._select_image_features(
             image_features,
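The NOTE in this hunk is the behavioral heart of the change: transformers' tower had to return every hidden state (`output_hidden_states=True`) so the caller could index into it with `vision_feature_layer`, whereas a ported tower can take that index as an argument, run only as many encoder blocks as needed, and return the selected hidden state directly. A minimal sketch of that idea, with hypothetical names (`SketchVisionTower` is illustrative, not vLLM's clip.py):

# Hypothetical sketch of feature-layer selection inside the tower.
import torch
import torch.nn as nn

class SketchVisionTower(nn.Module):
    def __init__(self, embed: nn.Module, blocks: nn.ModuleList):
        super().__init__()
        self.embed = embed    # patch + position embeddings
        self.blocks = blocks  # transformer encoder blocks

    def forward(self, pixel_values: torch.Tensor,
                feature_layer: int) -> torch.Tensor:
        hidden = self.embed(pixel_values)
        # HF's hidden_states tuple has len(blocks) + 1 entries, with the
        # embedding output first, so index f means "output after f encoder
        # blocks"; negative f counts back from the end of that tuple
        # (LLaVA's default of -2 stops one block before the last).
        stop = feature_layer if feature_layer >= 0 \
            else len(self.blocks) + 1 + feature_layer
        for block in self.blocks[:stop]:
            hidden = block(hidden)
        return hidden  # no output_hidden_states indexing needed upstream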
@@ -430,6 +428,9 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
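The new skip pairs with the port: since only an intermediate hidden state is consumed, the ported tower never instantiates `post_layernorm`, so the checkpoint's `vision_model.post_layernorm.*` tensors have no destination parameter and must be dropped before lookup. A hedged sketch of that loading pattern (`load_filtered_weights` and `SKIPPED_SUBSTRINGS` are illustrative names, not vLLM's API; key remapping via `_KEYS_TO_MODIFY_MAPPING` is omitted):

# Hypothetical sketch of the skip-then-load pattern used above.
from typing import Iterable, Tuple

import torch
import torch.nn as nn

SKIPPED_SUBSTRINGS = ("rotary_emb.inv_freq", "vision_model.post_layernorm")

def load_filtered_weights(model: nn.Module,
                          weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    params = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # Checkpoint tensors with no counterpart in the (intentionally
        # truncated) model would raise a KeyError below, so skip them.
        if any(s in name for s in SKIPPED_SUBSTRINGS):
            continue
        param = params[name]
        assert param.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)  # what default_weight_loader does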