2025-02-02 14:58:18 -05:00
# SPDX-License-Identifier: Apache-2.0
2025-06-03 11:20:17 -07:00
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2025-02-02 14:58:18 -05:00
2025-01-12 16:17:24 +08:00
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
""" Inference-only Deepseek-VL2 model compatible with HuggingFace weights. """
import math
2025-02-28 01:44:25 +08:00
from collections . abc import Iterable , Mapping , Sequence
2025-05-15 06:06:50 +01:00
from typing import Literal , Optional , TypedDict , Union
2025-01-12 16:17:24 +08:00
import torch
import torch . nn as nn
import torch . nn . functional as F
from einops import rearrange , repeat
2025-01-18 13:59:39 +08:00
from transformers import BatchFeature
2025-01-12 16:17:24 +08:00
from vllm . config import VllmConfig
2025-07-25 13:45:16 +08:00
from vllm . distributed import get_tensor_model_parallel_world_size
2025-01-12 16:17:24 +08:00
from vllm . model_executor import SamplingMetadata
from vllm . model_executor . layers . quantization import QuantizationConfig
from vllm . model_executor . model_loader . utils import set_default_torch_dtype
2025-07-25 13:45:16 +08:00
from vllm . model_executor . models . transformers import replace_linear_class
2025-01-12 16:17:24 +08:00
from vllm . multimodal import MULTIMODAL_REGISTRY
2025-04-11 03:32:14 +08:00
from vllm . multimodal . inputs import ( MultiModalDataDict , MultiModalFieldConfig ,
MultiModalKwargs , NestedTensors )
2025-01-12 16:17:24 +08:00
from vllm . multimodal . parse import ( ImageEmbeddingItems , ImageProcessorItems ,
ImageSize , MultiModalDataItems )
from vllm . multimodal . processing import ( BaseMultiModalProcessor ,
2025-04-29 09:40:35 +08:00
BaseProcessingInfo , MultiModalHashes ,
PromptReplacement , PromptUpdate )
2025-04-11 03:32:14 +08:00
from vllm . multimodal . profiling import BaseDummyInputsBuilder
2025-01-12 16:17:24 +08:00
from vllm . sequence import IntermediateTensors
from vllm . transformers_utils . configs . deepseek_vl2 import ( DeepseekVLV2Config ,
MlpProjectorConfig ,
VisionEncoderConfig )
2025-01-18 13:59:39 +08:00
from vllm . transformers_utils . processors . deepseek_vl2 import (
DeepseekVLV2Processor )
2025-02-19 21:13:50 +08:00
from vllm . transformers_utils . tokenizer import cached_tokenizer_from_config
2025-01-12 16:17:24 +08:00
from vllm . utils import is_list_of
2025-03-14 15:59:56 +08:00
from . interfaces import MultiModalEmbeddings , SupportsMultiModal , SupportsPP
2025-01-12 16:17:24 +08:00
from . utils import ( AutoWeightsLoader , WeightsMapper , flatten_bn ,
init_vllm_registered_model , maybe_prefix ,
merge_multimodal_embeddings )
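# The numeric id of the image token may vary between tokenizers, so the token
# is identified by its string form and its id is looked up in the vocabulary.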
_IMAGE_TOKEN = " <image> "
class DeepseekVL2ImagePixelInputs(TypedDict):
    """Image input variant tagged ``type="pixel_values"``."""

    # Discriminator distinguishing this variant from the embedding variant.
    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`
    data: Union[torch.Tensor, list[torch.Tensor]]

    # Shape: `(batch_size * num_images, 2)`
    images_spatial_crop: torch.Tensor
class DeepseekVL2VImageEmbeddingInputs(TypedDict):
    """Image input variant tagged ``type="image_embeds"``."""

    # Discriminator distinguishing this variant from the pixel-value variant.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
    # `hidden_size` must match the hidden size of language model backbone.
    data: Union[torch.Tensor, list[torch.Tensor]]
# Either accepted form of image input: pixel values (with their spatial crop
# layout) or image embeddings. The "type" key tells the two variants apart.
DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs,
                               DeepseekVL2VImageEmbeddingInputs]
class MlpProjector(nn.Module):
    """Downsampling MLP that maps vision features to `cfg.n_embed` channels.

    Only the "downsample_mlp_gelu" projector type without token pooling is
    supported.
    """

    def __init__(self, cfg: MlpProjectorConfig):
        """Build the Linear/GELU stack described by `cfg`.

        Raises:
            AssertionError: if `cfg.token_pooling` is set.
            NotImplementedError: if `cfg.projector_type` is not
                "downsample_mlp_gelu".
        """
        super().__init__()

        self.cfg = cfg
        assert not cfg.token_pooling, (
            "Token pooling is not supported currently.")

        if cfg.projector_type != "downsample_mlp_gelu":
            raise NotImplementedError(
                f"Unsupported projector type: {cfg.projector_type}")

        ratio = cfg.downsample_ratio
        hidden_dim = cfg.n_embed * cfg.mlp_ratio

        # Input layer consumes ratio x ratio neighbouring features at once.
        stack = [nn.Linear(cfg.input_dim * ratio * ratio, hidden_dim)]
        # `depth - 2` hidden layers sit between the input and output layers.
        for _ in range(cfg.depth - 2):
            stack += [nn.GELU(), nn.Linear(hidden_dim, hidden_dim)]
        stack += [nn.GELU(), nn.Linear(hidden_dim, cfg.n_embed)]

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Merge each ratio x ratio window of features, then apply the MLP.

        `x` is unpacked as `(batch, hw, input_dim)` and `hw` is treated as a
        square grid of side `int(hw ** 0.5)`.
        """
        batch, num_tokens, channels = x.shape
        ratio = self.cfg.downsample_ratio
        side = int(num_tokens**0.5)

        # Number of rows/columns needed to reach the next multiple of `ratio`
        # (zero when the side already divides evenly).
        pad = -side % ratio

        grid = x.reshape(batch, side, side, channels)
        if pad > 0:
            # Zero-pad the bottom and right edges of the grid only.
            grid = F.pad(grid, (0, 0, 0, pad, 0, pad), "constant", 0)

        # Concatenate every ratio x ratio window into a single feature vector.
        grid = grid.permute(0, 3, 1, 2)  # B, C, H, W
        windows = F.unfold(grid, kernel_size=ratio, stride=ratio,
                           padding=0)  # B, C * ratio * ratio, num_windows
        windows = windows.permute(0, 2, 1)

        return self.layers(windows)
class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for DeepSeek-VL2 (config, processor, token counts)."""

    def get_hf_config(self):
        """Return the HF config, requested as a `DeepseekVLV2Config`."""
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        """Return the `DeepseekVLV2Processor`, forwarding `kwargs` to it."""
        return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only the "image" modality is listed; `None` is its limit value."""
        return {"image": None}

    def get_num_image_tokens(self,
                             *,
                             image_width: int,
                             image_height: int,
                             cropping: bool = True) -> int:
        """Count the placeholder tokens one image expands to.

        Args:
            image_width: Width of the input image.
            image_height: Height of the input image.
            cropping: When true, the tile grid comes from the processor's
                `select_best_resolution`; otherwise a single 1x1 grid is used.

        Returns:
            Tokens of the global view plus tokens of the local views plus one.
        """
        processor = self.get_hf_processor()
        tile_size = processor.image_size

        if cropping:
            target_width, target_height = processor.select_best_resolution(
                (image_width, image_height))
            tiles_wide = target_width // tile_size
            tiles_high = target_height // tile_size
        else:
            tiles_wide = tiles_high = 1

        # Side length of one tile's feature grid after downsampling.
        side = math.ceil(
            (tile_size // processor.patch_size) / processor.downsample_ratio)

        # Every row of a view carries one extra token on top of its features.
        global_tokens = side * (side + 1)
        local_tokens = (tiles_high * side) * (tiles_wide * side + 1)

        # The trailing +1 is a single extra token beyond the two views.
        return global_tokens + local_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate resolution that yields the most image tokens.

        Entries of `candidate_resolutions` are read as `(height, width)`.
        """

        def token_count(resolution) -> int:
            return self.get_num_image_tokens(image_width=resolution[1],
                                             image_height=resolution[0])

        best_height, best_width = max(
            self.get_hf_config().candidate_resolutions, key=token_count)
        return ImageSize(width=best_width, height=best_height)
class DeepseekVL2DummyInputsBuilder(
        BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
    """Builds dummy text and image inputs for DeepSeek-VL2."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size that yields the most image tokens.

        `seq_len` is not used by this builder.
        """
        image_count = mm_counts.get("image", 0)
        largest = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=largest.width,
                                              height=largest.height,
                                              num_images=image_count)
        return {"image": dummy_images}
class DeepseekVL2MultiModalProcessor(
        BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
    """Multimodal processor for DeepSeek-VL2 prompts and images."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer when there is no data.

        When multimodal data is present, "pixel_values" in the output is
        replaced by a per-image split of the original tensor.
        """
        if not mm_data:
            # Text-only prompt: tokenize directly.
            tokenizer = self.info.get_tokenizer()
            return tokenizer(prompt,
                             add_special_tokens=True,
                             return_tensors="pt")

        outputs = self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(prompt=prompt, **mm_data),
            dict(**mm_kwargs, **tok_kwargs),
        )
        all_tiles = outputs["pixel_values"]

        # Split pixel values into the patches belonging to each image: the
        # product of an image's crop entry, plus one additional patch.
        crops = outputs["images_spatial_crop"]
        tiles_per_image = [crop.prod().item() + 1 for crop in crops]
        outputs["pixel_values"] = all_tiles.split(tiles_per_image)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare all three image fields as batched over "image"."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "images_spatial_crop": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one token per image feature."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token_id = processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(items, ImageEmbeddingItems):
                token_count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    # Cropping is requested only for prompts with at most
                    # two images.
                    cropping=len(items) <= 2,
                )
            return [image_token_id] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement_deepseek_vl2,
        )
        return [replacement]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        """Apply the processor, using the cache only for up to two images."""
        # The processor logic is different for len(images) <= 2 vs > 2.
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case.
        if mm_data_items.get_count("image", strict=False) > 2:
            apply_fn = self._apply_hf_processor
        else:
            apply_fn = super()._cached_apply_hf_processor

        return apply_fn(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )
2025-01-12 16:17:24 +08:00
@MULTIMODAL_REGISTRY.register_processor(
    DeepseekVL2MultiModalProcessor,
    info=DeepseekVL2ProcessingInfo,
    dummy_inputs=DeepseekVL2DummyInputsBuilder)
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """DeepSeek-VL2: a timm ViT, an MLP projector and a DeepSeek language model.

    Image tiles are encoded by the vision module, projected to the language
    model's embedding width, laid out with newline/separator embeddings and
    merged into the text embeddings at the image-token positions.
    """

    # Checkpoint weights prefixed "language." belong to `self.language_model`.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "language.": "language_model.",
    })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for the given modality.

        Raises:
            ValueError: if `modality` does not start with "image".
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, projector and language model.

        Raises:
            ValueError: if `config.tile_tag` is not "2D".
        """
        super().__init__()
        config: DeepseekVLV2Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        # The image token id is resolved from the tokenizer vocabulary
        # rather than hard-coded.
        model_config = vllm_config.model_config
        tokenizer = cached_tokenizer_from_config(model_config)
        self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]

        self.vision = self._init_vision_module(self.vision_config,
                                               quant_config,
                                               maybe_prefix(prefix, "vision"))

        self.projector = MlpProjector(self.projector_config)
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special token for image token sequence format
        embed_std = 1 / torch.sqrt(
            torch.tensor(self.projector_config.n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_seperator|>, <|\n|>
            self.image_newline = nn.Parameter(
                torch.randn(self.projector_config.n_embed) * embed_std)
            # This is a typo in original implementation
            self.view_seperator = nn.Parameter(
                torch.randn(self.projector_config.n_embed) * embed_std)
        else:
            raise ValueError(
                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
            )

        # Pick the language-model architecture from the text config.
        if self.text_config.topk_method == "noaux_tc":
            architectures = ["DeepseekV3ForCausalLM"]
        elif not self.text_config.use_mla:
            architectures = ["DeepseekForCausalLM"]
        else:
            architectures = ["DeepseekV2ForCausalLM"]

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.text_config,
            prefix=maybe_prefix(prefix, "language"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
        """Return (parent_module, final_attr_name) for a dotted module path."""
        names = dotted_name.split('.')
        parent = root
        for n in names[:-1]:
            parent = getattr(parent, n)
        return parent, names[-1]

    # patch for timm ViT instance to support tensor parallel
    def patch_vit_for_tp(self, vit: torch.nn.Module,
                         quant_config: Optional[QuantizationConfig]):
        """Swap the `fc1`/`fc2` linears of every timm `Mlp` in `vit`.

        `fc1` is replaced with the "colwise" variant and `fc2` with the
        "rowwise" variant produced by `replace_linear_class`. Other linear
        layers are left untouched. The module is patched in place and
        returned.

        Raises:
            ImportError: if `timm` is not installed.
        """
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        tp_style_by_attr = {"fc1": "colwise", "fc2": "rowwise"}

        # Iterate over a snapshot: the loop body replaces submodules, and the
        # module tree should not be mutated under a live generator.
        for name, module in list(vit.named_modules()):
            if not isinstance(module, nn.Linear):
                continue
            parent, attr_name = self._get_parent_and_attr(vit, name)
            tp_style = tp_style_by_attr.get(attr_name)
            if tp_style is not None and isinstance(parent, timm.layers.Mlp):
                new_linear = replace_linear_class(module, tp_style,
                                                  quant_config)
                setattr(parent, attr_name, new_linear)

        return vit

    def _init_vision_module(
        self,
        vision_config: VisionEncoderConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Create the timm SigLIP ViT used as the vision encoder.

        The model is created without pretrained weights, patched for tensor
        parallelism when the TP world size is above one, and finally cast to
        the current default dtype.

        Raises:
            ImportError: if `timm` is not installed.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        if get_tensor_model_parallel_world_size() > 1:
            model = self.patch_vit_for_tp(model, quant_config)

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check that every item has trailing shape `(3, image_size, image_size)`.

        Returns `data` unchanged.

        Raises:
            ValueError: if any item has a different trailing shape.
        """
        h = w = self.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # The leading dimension (number of patches) is unconstrained.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_images_spatial_crop(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check that every item's last dimension has size 2.

        Returns `data` unchanged.

        Raises:
            ValueError: if any item's last dimension is not 2.
        """
        expected_dims = 2

        def _validate_shape(d: torch.Tensor):
            actual_dims = d.size(-1)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]:
        """Extract image inputs from `kwargs` and validate them.

        Pixel values take precedence over image embeddings when both are
        given. Returns `None` when neither is present.

        Raises:
            ValueError: if a supplied input is neither a tensor nor a list,
                or fails shape validation.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(images_spatial_crop)}")

            return DeepseekVL2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                images_spatial_crop=self._validate_images_spatial_crop(
                    flatten_bn(images_spatial_crop, concat=True)))

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return DeepseekVL2VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _pixel_values_to_embedding(
        self,
        pixel_values: NestedTensors,
        images_spatial_crop: torch.Tensor,
    ) -> NestedTensors:
        """Encode image tiles and lay them out as one embedding per image.

        For each image, the first tile is the global view and the following
        `num_width_tiles * num_height_tiles` tiles are the local views. A
        newline embedding is appended to every row of both views, and the
        views are joined with the view separator in the order given by
        `self.global_view_pos`.

        Returns:
            A list with one 2D tensor of embeddings per processed image.
        """
        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
        # [batch_all_tiles, 3, height, width]
        total_tiles = torch.cat(list(pixel_values), dim=0)

        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision.forward_features(total_tiles)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)

        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)

        # fill image token based on self.tile_tag & self.global_view_pos
        tile_index = 0
        vision_embeddings = []
        for jdx in range(images_spatial_crop.size(0)):
            # extra global & local features
            num_width_tiles, num_height_tiles = images_spatial_crop[jdx]
            # A zero-sized crop entry ends processing of all remaining rows.
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[tile_index + 1:tile_index + 1 +
                                           num_tiles_in_image]
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global],
                                        dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(local_features,
                                       "(th tw) (h w) d -> (th h) (tw w) d",
                                       th=num_height_tiles,
                                       tw=num_width_tiles,
                                       h=h,
                                       w=w)

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(self.image_newline,
                                        "d -> (th h) 1 d",
                                        th=num_height_tiles,
                                        h=h)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local],
                                       dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat([
                    global_features,
                    self.view_seperator[None, :],
                    local_features,
                ])
            else:
                global_local_features = torch.cat([
                    local_features,
                    self.view_seperator[None, :],
                    global_features,
                ])

            vision_embeddings.append(global_local_features)
        return vision_embeddings

    def _process_image_input(
            self, image_input: DeepseekVL2ImageInputs) -> NestedTensors:
        """Turn a validated image input into per-image embeddings.

        Embedding inputs are passed through (a 3D tensor is unbound along
        dim 0 into a list); pixel inputs are encoded by the vision tower.

        Raises:
            ValueError: if embedding input is neither a list of tensors nor
                a 3D tensor.
        """
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        pixel_values = image_input["data"]
        images_spatial_crop = image_input["images_spatial_crop"]

        return self._pixel_values_to_embedding(
            pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for `kwargs`, or `[]` without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging image embeddings at image-token slots.

        The merge is skipped when `multimodal_embeddings` is `None` or empty.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.image_token_id)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        """Run the language model and return its hidden states.

        When `intermediate_tensors` is given, `inputs_embeds` is discarded.
        Otherwise, if `inputs_embeds` is missing, it is computed here from
        `input_ids` and the image inputs in `kwargs`, and `input_ids` is
        cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Returns the loader's result (annotated as a set of names).
        """
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights,
                                                 mapper=self.hf_to_vllm_mapper)
        return autoloaded_weights