2025-02-02 14:58:18 -05:00
# SPDX-License-Identifier: Apache-2.0
2025-06-03 11:20:17 -07:00
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2025-02-02 14:58:18 -05:00
2025-01-12 16:17:24 +08:00
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
""" Inference-only Deepseek-VL2 model compatible with HuggingFace weights. """
2025-10-05 15:06:22 +01:00
2025-01-12 16:17:24 +08:00
import math
2025-02-28 01:44:25 +08:00
from collections . abc import Iterable , Mapping , Sequence
2025-07-26 19:34:11 -07:00
from typing import Annotated , Literal , TypeAlias
2025-01-12 16:17:24 +08:00
import torch
import torch . nn as nn
import torch . nn . functional as F
from einops import rearrange , repeat
2025-01-18 13:59:39 +08:00
from transformers import BatchFeature
2025-01-12 16:17:24 +08:00
from vllm . config import VllmConfig
2025-10-03 03:59:10 -07:00
from vllm . config . multimodal import BaseDummyOptions
2025-07-25 13:45:16 +08:00
from vllm . distributed import get_tensor_model_parallel_world_size
2025-01-12 16:17:24 +08:00
from vllm . model_executor . layers . quantization import QuantizationConfig
2025-10-16 22:50:39 +01:00
from vllm . model_executor . models . transformers . utils import replace_linear_class
2025-01-12 16:17:24 +08:00
from vllm . multimodal import MULTIMODAL_REGISTRY
2025-04-11 03:32:14 +08:00
from vllm . multimodal . inputs import (
MultiModalDataDict ,
MultiModalFieldConfig ,
2025-10-02 23:17:35 +08:00
MultiModalKwargsItems ,
)
2025-01-12 16:17:24 +08:00
from vllm . multimodal . parse import (
ImageEmbeddingItems ,
ImageProcessorItems ,
ImageSize ,
MultiModalDataItems ,
)
2026-01-14 23:25:31 +08:00
from vllm . multimodal . processing import BaseDummyInputsBuilder
from vllm . multimodal . processing . processor import (
2025-01-12 16:17:24 +08:00
BaseMultiModalProcessor ,
2025-08-18 20:31:53 +08:00
BaseProcessingInfo ,
MultiModalProcessingInfo ,
2026-02-23 22:15:50 +08:00
ProcessorInputs ,
2025-04-29 09:40:35 +08:00
PromptReplacement ,
PromptUpdate ,
2026-02-23 22:15:50 +08:00
TimingContext ,
2025-04-29 09:40:35 +08:00
)
2025-01-12 16:17:24 +08:00
from vllm . sequence import IntermediateTensors
2025-12-02 13:33:37 +08:00
from vllm . tokenizers import cached_tokenizer_from_config
2025-01-12 16:17:24 +08:00
from vllm . transformers_utils . configs . deepseek_vl2 import (
DeepseekVLV2Config ,
MlpProjectorConfig ,
VisionEncoderConfig ,
)
2025-01-18 13:59:39 +08:00
from vllm . transformers_utils . processors . deepseek_vl2 import DeepseekVLV2Processor
2025-07-26 19:34:11 -07:00
from vllm . utils . tensor_schema import TensorSchema , TensorShape
2025-10-19 00:48:22 +08:00
from vllm . utils . torch_utils import set_default_torch_dtype
2025-01-12 16:17:24 +08:00
2025-03-14 15:59:56 +08:00
from . interfaces import MultiModalEmbeddings , SupportsMultiModal , SupportsPP
2025-10-02 23:17:35 +08:00
from . utils import (
AutoWeightsLoader ,
WeightsMapper ,
2025-09-27 16:15:12 +08:00
init_vllm_registered_model ,
maybe_prefix ,
)
2025-01-12 16:17:24 +08:00
# The image token id may vary between tokenizers, so the id is looked up from this token string
_IMAGE_TOKEN = " <image> "
2025-07-26 19:34:11 -07:00
class DeepseekVL2ImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as raw pixel tiles.

    Dimensions:
        - bnp: Batch size * number of images * number of patches
        - bn: Batch size * number of images
        - p: Number of patches
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Discriminator used to tell this schema apart from the embedding variant.
    type: Literal["pixel_values"]
    # Tiles of every image stacked along the leading axis; only "bnp" is
    # declared dynamic.
    data: Annotated[
        torch.Tensor, TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})
    ]
    # One row of two values per image; the model unpacks each row as
    # (num_width_tiles, num_height_tiles).
    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
2025-01-12 16:17:24 +08:00
2025-07-26 19:34:11 -07:00
class DeepseekVL2VImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match language model backbone)
    """

    # Discriminator used to tell this schema apart from the pixel variant.
    type: Literal["image_embeds"]
    # Either a single stacked tensor or a list of tensors.
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", "h")]
2025-01-12 16:17:24 +08:00
# Either form of image input accepted by the model: raw pixel tiles or
# precomputed embeddings. The two are distinguished by their `type` field.
DeepseekVL2ImageInputs: TypeAlias = (
    DeepseekVL2ImagePixelInputs | DeepseekVL2VImageEmbeddingInputs
)
2025-01-12 16:17:24 +08:00
class MlpProjector(nn.Module):
    """Map vision features of width ``cfg.input_dim`` to width ``cfg.n_embed``.

    Two projector types are supported:

    * ``"linear"``: a single ``nn.Linear`` applied to the input as-is.
    * ``"downsample_mlp_gelu"``: the token sequence is treated as a square
      grid, every ``downsample_ratio x downsample_ratio`` neighbourhood is
      folded into the channel axis, and the result goes through a
      Linear/GELU stack.

    Any other ``cfg.projector_type`` raises ``NotImplementedError``.
    """

    def __init__(self, cfg: MlpProjectorConfig):
        super().__init__()

        self.cfg = cfg
        self.projector_type = cfg.projector_type
        assert not cfg.token_pooling, "Token pooling is not supported currently."

        if self.projector_type == "linear":
            projection = nn.Linear(cfg.input_dim, cfg.n_embed)
        elif self.projector_type == "downsample_mlp_gelu":
            hidden_dim = cfg.n_embed * cfg.mlp_ratio
            # Each output token carries a full ratio x ratio window of inputs.
            folded_dim = cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio

            stack = [nn.Linear(folded_dim, hidden_dim)]
            # depth counts Linear layers; the first and last are added
            # outside this loop, so only depth - 2 hidden layers go here.
            for _ in range(cfg.depth - 2):
                stack += [nn.GELU(), nn.Linear(hidden_dim, hidden_dim)]
            stack += [nn.GELU(), nn.Linear(hidden_dim, cfg.n_embed)]
            projection = nn.Sequential(*stack)
        else:
            raise NotImplementedError(
                f"Unsupported projector type: {cfg.projector_type}"
            )

        self.layers = projection

    def forward(self, x):
        """Project ``x`` of shape ``[batch, tokens, input_dim]``.

        For ``"downsample_mlp_gelu"`` the token count is reshaped to a
        square grid, so it is expected to be a perfect square.
        """
        bs, hw, input_dim = x.shape

        if self.projector_type != "downsample_mlp_gelu":
            return self.layers(x)

        ratio = self.cfg.downsample_ratio
        side = int(hw**0.5)

        # compute padding: grow the grid up to the next multiple of `ratio`
        remainder = side % ratio
        pad = ratio - remainder if remainder else 0

        grid = x.reshape(bs, side, side, input_dim)
        if pad > 0:
            # Zero-pad the end of both spatial axes; channels are untouched.
            grid = F.pad(grid, (0, 0, 0, pad, 0, pad), "constant", 0)

        # fold each ratio x ratio window into the channel axis
        grid = grid.permute(0, 3, 1, 2)  # B, C, H, W
        windows = F.unfold(
            grid,
            kernel_size=ratio,
            stride=ratio,
            padding=0,
        )  # B, C * ratio * ratio, number of windows
        tokens = windows.permute(0, 2, 1)

        return self.layers(tokens)
class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for DeepSeek-VL2: config, processor, token counts."""

    def get_hf_config(self):
        """Return the HF config, requested as ``DeepseekVLV2Config``."""
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, requested as ``DeepseekVLV2Processor``."""
        return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported; ``None`` means no fixed per-prompt limit."""
        return {"image": None}

    def get_num_image_tokens(
        self, *, image_width: int, image_height: int, cropping: bool = True
    ) -> int:
        """Return how many placeholder tokens one image expands to.

        Args:
            image_width: Width of the input image.
            image_height: Height of the input image.
            cropping: When true, the tile grid comes from the processor's
                best-fit resolution; otherwise a single tile is assumed.
        """
        processor = self.get_hf_processor()
        tile_size = processor.image_size
        patch_size = processor.patch_size
        downsample_ratio = processor.downsample_ratio

        tiles_wide = tiles_high = 1
        if cropping:
            best_width, best_height = processor.select_best_resolution(
                (image_width, image_height)
            )
            tiles_wide = best_width // tile_size
            tiles_high = best_height // tile_size

        # Side length of one tile's feature grid after downsampling.
        side = math.ceil((tile_size // patch_size) / downsample_ratio)

        # Each row of features is followed by one extra token.
        global_tokens = side * (side + 1)
        local_tokens = (tiles_high * side) * (tiles_wide * side + 1)

        # One more token sits between the global and the local view.
        return global_tokens + local_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate resolution that yields the most image tokens."""

        def token_count(resolution) -> int:
            # Candidate resolutions are stored as (height, width).
            return self.get_num_image_tokens(
                image_width=resolution[1], image_height=resolution[0]
            )

        candidates = self.get_hf_config().candidate_resolutions
        best_height, best_width = max(candidates, key=token_count)
        return ImageSize(width=best_width, height=best_height)
class DeepseekVL2DummyInputsBuilder(BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
    """Build placeholder prompts and images sized for the largest input."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per requested image."""
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images at the resolution that yields the most tokens.

        ``seq_len`` is accepted for interface compatibility and is not used.
        """
        largest = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=largest.width,
            height=largest.height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}
class DeepseekVL2MultiModalProcessor(
    BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]
):
    """Multimodal processor that expands image placeholders for DeepSeek-VL2."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer for text-only prompts.

        When multimodal data is present, a ``num_patches`` entry is added
        to the processor output.
        """
        if mm_data:
            outputs = super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
            # Product of each image's crop grid, plus one extra tile.
            tile_grid = outputs["images_spatial_crop"]
            outputs["num_patches"] = tile_grid.prod(-1) + 1
            return outputs

        # Text-only prompt: tokenize directly.
        tokenizer = self.info.get_tokenizer()
        return tokenizer(prompt, add_special_tokens=True, return_tensors="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps to image items."""
        # Falls back to an empty tensor when no "num_patches" entry exists.
        patches_per_image = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            "images_spatial_crop": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with as many copies as the image needs."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token_id = processor.image_token_id
        assert isinstance(image_token_id, int)

        def expand_image_token(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings already know their own length.
                token_count = images.get_feature_size(item_idx)
                return [image_token_id] * token_count

            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
                # Cropping is only applied for prompts with at most 2 images.
                cropping=len(images) <= 2,
            )
            return [image_token_id] * token_count

        image_update = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=expand_image_token,
        )
        return [image_update]

    def _cached_apply_hf_processor(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """Apply the HF processor, using the cache only for <= 2 images."""
        # The processor logic is different for len(images) <= 2 vs > 2
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        image_count = inputs.mm_data_items.get_count("image", strict=False)
        if image_count <= 2:
            return super()._cached_apply_hf_processor(inputs, timing_ctx)

        return self._apply_hf_processor(inputs, timing_ctx)
2025-03-07 18:33:38 +08:00
2025-01-12 16:17:24 +08:00
@MULTIMODAL_REGISTRY.register_processor(
    DeepseekVL2MultiModalProcessor,
    info=DeepseekVL2ProcessingInfo,
    dummy_inputs=DeepseekVL2DummyInputsBuilder,
)
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """DeepSeek-VL2: a timm vision tower and MLP projector in front of an LM.

    Image tiles are encoded by the vision tower, projected to the language
    model width, laid out with newline and view-separator embeddings, and
    returned as multimodal embeddings for the language model.
    """

    # Checkpoint names starting with "language." are remapped to
    # "language_model." on load.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language.": "language_model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, projector, layout embeddings and LM.

        Raises:
            ValueError: If ``config.tile_tag`` is anything other than ``"2D"``.
        """
        super().__init__()
        config: DeepseekVLV2Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        model_config = vllm_config.model_config
        tokenizer = cached_tokenizer_from_config(model_config)
        self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]

        with self._mark_tower_model(vllm_config, "image"):
            self.vision = self._init_vision_module(
                self.vision_config, quant_config, maybe_prefix(prefix, "vision")
            )

            self.projector = MlpProjector(self.projector_config)
            self.tile_tag = config.tile_tag
            self.global_view_pos = config.global_view_pos

            # special token for image token sequence format
            embed_std = 1 / torch.sqrt(
                torch.tensor(self.projector_config.n_embed, dtype=torch.float32)
            )
            if self.tile_tag == "2D":
                # <|view_seperator|>, <|\n|>
                self.image_newline = nn.Parameter(
                    torch.randn(self.projector_config.n_embed) * embed_std
                )
                # This is a typo in original implementation
                self.view_seperator = nn.Parameter(
                    torch.randn(self.projector_config.n_embed) * embed_std
                )
            else:
                raise ValueError(
                    f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
                )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=self.text_config,
                prefix=maybe_prefix(prefix, "language"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
        """Return (parent_module, final_attr_name) for a dotted module path."""
        names = dotted_name.split(".")
        parent = root
        for n in names[:-1]:
            parent = getattr(parent, n)
        return parent, names[-1]

    # patch for timm ViT instance to support tensor parallel
    def patch_vit_for_tp(
        self, vit: torch.nn.Module, quant_config: QuantizationConfig | None
    ):
        """Swap the MLP linears of a timm ViT for tensor-parallel versions.

        Inside every ``timm.layers.Mlp``, ``fc1`` is replaced with a
        column-wise linear and ``fc2`` with a row-wise linear. All other
        modules are left alone. ``vit`` is modified in place and returned.

        Raises:
            ImportError: If ``timm`` is not installed.
        """
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        # Snapshot the module list first: sub-modules are replaced inside the
        # loop, and mutating the tree while walking the live generator is
        # fragile.
        for name, module in list(vit.named_modules()):
            if not isinstance(module, nn.Linear):
                continue

            parent, attr_name = self._get_parent_and_attr(vit, name)
            if not isinstance(parent, timm.layers.Mlp):
                continue

            if attr_name == "fc1":
                new_linear = replace_linear_class(
                    module, "colwise", quant_config, prefix=name
                )
                setattr(parent, attr_name, new_linear)
            elif attr_name == "fc2":
                new_linear = replace_linear_class(
                    module, "rowwise", quant_config, prefix=name
                )
                setattr(parent, attr_name, new_linear)

        return vit

    def _init_vision_module(
        self,
        vision_config: VisionEncoderConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the timm vision tower used to encode image tiles.

        ``vision_config`` and ``prefix`` are currently unused; the
        architecture name is fixed. The model is created with float16 as the
        default dtype, patched for tensor parallelism when the TP world size
        is greater than 1, then cast to the current default dtype.

        Raises:
            ImportError: If ``timm`` is not installed.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        if get_tensor_model_parallel_world_size() > 1:
            model = self.patch_vit_for_tp(model, quant_config)

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> DeepseekVL2ImageInputs | None:
        """Build a validated image-input schema from the forward kwargs.

        Pops ``pixel_values``, ``images_spatial_crop`` and ``image_embeds``
        from ``kwargs``. Returns ``None`` when neither pixels nor embeddings
        are present. Pixel values take precedence when both are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            # Tiles are expected to be square with side `image_size`.
            expected_h = expected_w = self.vision_config.image_size
            return DeepseekVL2ImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                images_spatial_crop=images_spatial_crop,
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w,
                },
            )

        if image_embeds is not None:
            return DeepseekVL2VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Encode image tiles and lay them out as one embedding per image.

        For every image, the first tile is the global view and the following
        ``num_width_tiles * num_height_tiles`` tiles are the local views.
        A newline embedding is appended to each feature row, and the two
        views are joined by the view-separator embedding. Their order is
        chosen by ``self.global_view_pos``.

        Args:
            pixel_values: All tiles of all images, stacked on the first axis.
            images_spatial_crop: One ``(num_width_tiles, num_height_tiles)``
                row per image.

        Returns:
            One ``[num_tokens, D]`` tensor per image. Iteration stops at the
            first row with a zero tile count.
        """
        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision.forward_features(pixel_values)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)

        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)

        # Convert the tile grid to Python ints once. This avoids a tensor
        # comparison and tensor-valued slice bounds on every iteration, and
        # hands plain ints to einops as axis lengths.
        tile_grid = images_spatial_crop.tolist()

        # fill image token based on self.tile_tag & self.global_view_pos
        tile_index = 0
        vision_embeddings = []
        for num_width_tiles, num_height_tiles in tile_grid:
            # extra global & local features
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline, "d -> (th h) 1 d", th=num_height_tiles, h=h
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat(
                    [
                        global_features,
                        self.view_seperator[None, :],
                        local_features,
                    ]
                )
            else:
                global_local_features = torch.cat(
                    [
                        local_features,
                        self.view_seperator[None, :],
                        global_features,
                    ]
                )

            vision_embeddings.append(global_local_features)
        return vision_embeddings

    def _process_image_input(
        self, image_input: DeepseekVL2ImageInputs
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return embeddings for a validated image input.

        Precomputed embeddings are passed through unchanged; pixel inputs are
        run through the vision tower and projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        pixel_values = image_input["data"]
        images_spatial_crop = image_input["images_spatial_crop"]

        return self._pixel_values_to_embedding(
            pixel_values=pixel_values, images_spatial_crop=images_spatial_crop
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for ``kwargs``, or ``[]`` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        provided. Extra ``kwargs`` are accepted but not used here.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, applying the ``language.`` prefix remap.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        return autoloaded_weights