[VLM] Clean up Phi-4-MM ViT implementation (#14812)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Isotr0py
Date: 2025-03-16 09:53:52 +08:00
Committed by: GitHub
Parent: 3453b964a3
Commit: def232e122
7 changed files with 316 additions and 1988 deletions


@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, SiglipVisionConfig
 from transformers.utils import logging
 
 from vllm.config import VllmConfig
@@ -32,10 +32,10 @@ from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
+from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
-from .vision_siglip_navit import get_siglip_vision_model
 
 # <|endoftext10|> (see vocab.json in hf model)
 _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
@@ -339,6 +339,33 @@ def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
     return data
 
 
+def get_navit_vision_model(layer_idx: int = -1, **kwargs):
+    vision_config = {
+        "hidden_size": 1152,
+        "image_size": 448,
+        "intermediate_size": 4304,
+        "model_type": "siglip_vision_model",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "patch_size": 14,
+    }
+    model_config = SiglipVisionConfig(**vision_config, **kwargs)
+
+    if layer_idx < 0:
+        num_hidden_layers = model_config.num_hidden_layers \
+            + layer_idx + 1
+    else:
+        num_hidden_layers = layer_idx + 1
+
+    vision_model = Idefics2VisionTransformer(
+        config=model_config,
+        require_post_norm=False,
+        num_hidden_layers_override=num_hidden_layers,
+    )
+
+    return vision_model
+
+
 class Phi4MMImageEncoder(nn.Module):
     """Image embedding."""
 
@@ -362,8 +389,7 @@ class Phi4MMImageEncoder(nn.Module):
         self.layer_idx = -2
         self.type_feature = 'patch'
 
-        self.img_processor = get_siglip_vision_model(
-            _flash_attn_2_enabled=True)
+        self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx)
 
         pe_weight = self.img_processor.embeddings.position_embedding.weight
         L, D = pe_weight.size()
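
The two context lines above read the position-embedding table; a hedged sketch of what they yield (the grid-side derivation and the variable `side` are illustrative assumptions, only pe_weight, L and D appear in the diff; the 32-per-side figure follows from the SiglipVisionConfig above, 448 // 14 = 32):

import math

import torch

pe_weight = torch.randn(32 * 32, 1152)  # stand-in for embeddings.position_embedding.weight
L, D = pe_weight.size()                 # L positions (1024), each a D-dim (1152) embedding
side = int(math.sqrt(L))                # assumed square patch grid: 32 patches per side
assert side * side == L
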
@@ -430,16 +456,11 @@ class Phi4MMImageEncoder(nn.Module):
     def get_img_features(self,
                          img_embeds: torch.FloatTensor,
                          attention_mask=None) -> torch.FloatTensor:
-        LAYER_IDX = self.layer_idx
-        TYPE_FEATURE = self.type_feature
-
-        img_processor_output = self.img_processor(
-            img_embeds,
-            output_hidden_states=True,
-            patch_attention_mask=attention_mask)
-        img_feature = img_processor_output.hidden_states[LAYER_IDX]
+        img_feature = self.img_processor(img_embeds,
+                                         patch_attention_mask=attention_mask)
 
-        if TYPE_FEATURE == "patch":
+        if self.type_feature == "patch":
             patch_feature = img_feature
 
             use_token_compression = self.image_token_compression is not None
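
A toy sketch (not vLLM code) of why the simplified call path above is equivalent to the old one: indexing hidden_states[-2] from a full encoder stack matches the final output of the same stack truncated by one layer, which is what num_hidden_layers_override provides; require_post_norm=False keeps that intermediate state from being passed through a final post-layernorm it never saw on the old path.

import torch
import torch.nn as nn

# Toy stand-in for a transformer encoder: a stack of identical blocks.
torch.manual_seed(0)
blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
x = torch.randn(2, 8)

# Old path: run every block, collect hidden states, then index [-2].
hidden_states = [x]
h = x
for block in blocks:
    h = block(h)
    hidden_states.append(h)
old_feature = hidden_states[-2]

# New path: run a stack truncated by one block and take its final output.
h = x
for block in blocks[:-1]:
    h = block(h)
new_feature = h

assert torch.equal(old_feature, new_feature)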