[VLM] Clean up Phi-4-MM ViT implementation (#14812)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,7 +11,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import PretrainedConfig, SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
@@ -32,10 +32,10 @@ from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal
|
||||
from .phi4mm_audio import AudioEmbedding
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
from .vision_siglip_navit import get_siglip_vision_model
|
||||
|
||||
# <|endoftext10|> (see vocab.json in hf model)
|
||||
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
|
||||
@@ -339,6 +339,33 @@ def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
|
||||
return data
|
||||
|
||||
|
||||
def get_navit_vision_model(layer_idx: int = -1, **kwargs):
    """Build an ``Idefics2VisionTransformer`` from a SigLIP vision config.

    Args:
        layer_idx: Selects how many encoder layers to keep. A negative
            value counts back from the configured layer count (``-1``
            keeps every layer, ``-2`` drops the final one); a non-negative
            value ``k`` keeps ``k + 1`` layers.
        **kwargs: Extra ``SiglipVisionConfig`` fields. A key that matches
            one of the built-in defaults below replaces that default.

    Returns:
        The ``Idefics2VisionTransformer`` constructed with
        ``require_post_norm=False`` and ``num_hidden_layers_override`` set
        to the layer count derived from ``layer_idx``.
    """
    vision_config = {
        "hidden_size": 1152,
        "image_size": 448,
        "intermediate_size": 4304,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }
    # Merge caller overrides into the defaults instead of splatting both
    # mappings into the call: passing the same key twice as a keyword
    # argument raises TypeError, which made the defaults impossible to
    # override.
    vision_config.update(kwargs)

    model_config = SiglipVisionConfig(**vision_config)
    if layer_idx < 0:
        # e.g. 27 layers with layer_idx=-2 -> 26 layers kept.
        num_hidden_layers = model_config.num_hidden_layers + layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1

    # NOTE(review): presumably the override limits how many encoder layers
    # are instantiated, and require_post_norm=False omits the final norm --
    # confirm against Idefics2VisionTransformer.
    vision_model = Idefics2VisionTransformer(
        config=model_config,
        require_post_norm=False,
        num_hidden_layers_override=num_hidden_layers,
    )

    return vision_model
|
||||
|
||||
|
||||
class Phi4MMImageEncoder(nn.Module):
|
||||
"""Image embedding."""
|
||||
|
||||
@@ -362,8 +389,7 @@ class Phi4MMImageEncoder(nn.Module):
|
||||
self.layer_idx = -2
|
||||
self.type_feature = 'patch'
|
||||
|
||||
self.img_processor = get_siglip_vision_model(
|
||||
_flash_attn_2_enabled=True)
|
||||
self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx)
|
||||
|
||||
pe_weight = self.img_processor.embeddings.position_embedding.weight
|
||||
L, D = pe_weight.size()
|
||||
@@ -430,16 +456,11 @@ class Phi4MMImageEncoder(nn.Module):
|
||||
def get_img_features(self,
|
||||
img_embeds: torch.FloatTensor,
|
||||
attention_mask=None) -> torch.FloatTensor:
|
||||
LAYER_IDX = self.layer_idx
|
||||
TYPE_FEATURE = self.type_feature
|
||||
|
||||
img_processor_output = self.img_processor(
|
||||
img_embeds,
|
||||
output_hidden_states=True,
|
||||
patch_attention_mask=attention_mask)
|
||||
img_feature = img_processor_output.hidden_states[LAYER_IDX]
|
||||
img_feature = self.img_processor(img_embeds,
|
||||
patch_attention_mask=attention_mask)
|
||||
|
||||
if TYPE_FEATURE == "patch":
|
||||
if self.type_feature == "patch":
|
||||
patch_feature = img_feature
|
||||
|
||||
use_token_compression = self.image_token_compression is not None
|
||||
|
||||
Reference in New Issue
Block a user