[Model] Port over CLIPVisionModel for VLMs (#5591)
This commit is contained in:
@@ -2,9 +2,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
|
||||
# transformers' impl.
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from transformers import LlavaConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
@@ -15,6 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -189,12 +188,11 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
|
||||
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
|
||||
pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
|
||||
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
|
||||
output_hidden_states=True)
|
||||
|
||||
image_features = image_outputs.hidden_states[
|
||||
self.config.vision_feature_layer]
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values.to(vision_tower.device),
|
||||
self.config.vision_feature_layer)
|
||||
|
||||
return self._select_image_features(
|
||||
image_features,
|
||||
@@ -317,6 +315,9 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
# post_layernorm is not needed in CLIPVisionModel
|
||||
if "vision_model.post_layernorm" in name:
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
|
||||
Reference in New Issue
Block a user