[Model] Port over CLIPVisionModel for VLMs (#5591)
This commit is contained in:
@@ -17,7 +17,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
|
||||
from transformers import CLIPVisionConfig, PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@@ -70,9 +71,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
LAYER_IDX = self.layer_idx
|
||||
TYPE_FEATURE = self.type_feature
|
||||
|
||||
img_processor_output = self.img_processor(img_embeds,
|
||||
output_hidden_states=True)
|
||||
img_feature = img_processor_output.hidden_states[LAYER_IDX]
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the img_processor
|
||||
img_feature = self.img_processor(img_embeds,
|
||||
vision_feature_layer=LAYER_IDX)
|
||||
|
||||
if TYPE_FEATURE == "patch":
|
||||
patch_feature = img_feature[:, 1:]
|
||||
@@ -352,6 +354,9 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
# post_layernorm is not needed in CLIPVisionModel
|
||||
if "vision_model.post_layernorm" in name:
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
|
||||
Reference in New Issue
Block a user