# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# This file is meant to be used in kimi_vl.py only
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections.abc import Sequence
from copy import deepcopy
from functools import cached_property
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN, PytorchGELUTanh
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available

from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.configs.moonvit import MoonViTConfig

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None


def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Multi-head attention using flash attention 2.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
            The first element should be 0 and the last element should be k.shape[0].

    Returns:
        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
            where dim = num_heads * head_dim.
    """
    # Unified format legal check
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[
        0], "q_cu_seqlens must sum to q.shape[0]"
    assert (k_cu_seqlens[-1] == k.shape[0] ==
            v.shape[0]), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()

    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
    )
    attn_out = attn_out.flatten(start_dim=-2)
    return attn_out
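
# Illustrative example (not exercised by the model code): for two packed images
# of 4 and 6 patches, q/k/v have shape (10, num_heads, head_dim) and
# q_cu_seqlens = k_cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32),
# so flash_attn_varlen_func attends within each image but never across images.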


def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SDPA attention.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
    """
    seq_length = q.shape[0]
    attention_mask = torch.zeros([1, seq_length, seq_length],
                                 device=q.device,
                                 dtype=torch.bool)
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
        ] = True
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_output = F.scaled_dot_product_attention(q,
                                                 k,
                                                 v,
                                                 attention_mask,
                                                 dropout_p=0.0)
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output


VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "sdpa": sdpa_attention,
}


def _apply_rope_input_validation(x, freqs_cis):
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype


def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64.
            It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
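
# Note: the complex multiplication above is the usual RoPE rotation written in
# complex form. For one (even, odd) channel pair (x0, x1) rotated by angle t:
#   x0_out = x0 * cos(t) - x1 * sin(t)
#   x1_out = x0 * sin(t) + x1 * cos(t)
# i.e. (x0 + i*x1) * (cos(t) + i*sin(t)) expanded into real and imaginary parts.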


class Learnable2DInterpPosEmb(nn.Module):

    def __init__(self,
                 height: int,
                 width: int,
                 dim: int,
                 interpolation_mode: str = "bicubic") -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor,
                grid_hws: torch.Tensor) -> torch.Tensor:
        pos_embs = []
        for shape in grid_hws.tolist():
            if shape == self.weight.shape[:-1]:
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                pos_embs.append(
                    F.interpolate(
                        self.weight.permute((2, 0, 1)).unsqueeze(0),
                        size=shape,
                        mode=self.interpolation_mode,
                    ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
        out = x + torch.cat(pos_embs)
        return out


class MoonVisionPatchEmbed(nn.Module):

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: Union[int, tuple[int, int]] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(
            patch_size,
            (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}"
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height,
                                               width=pos_emb_width,
                                               dim=out_dim)

    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_hw (N, 2): grid height and width

        Returns:
            (L, Cout) tensor
        """
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_hw)
        return x


class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.

    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000,
                 device="cuda"):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.device = device

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    @cached_property
    def precomputed_freqs_cis(self) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
            note: `cis` is a mathematical notation defined by cis x = cos x + i sin x.
        """
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(self.device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(self.device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis_by_seqlens(self,
                                 grid_hws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        shapes = grid_hws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        freqs_cis = torch.cat(
            [
                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
                for h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis

    def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
                             pos_idx_mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pos_idx: tensor of shape (..., 2). It contains the (h, w) position indices of each 2D token.
            pos_idx_mask: a mask of shape (...); the leading dimensions should be the same as pos_idx.
                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask will be ones.
        Return:
            freqs_cis: tensor of shape (..., dim//2)
        """
        assert (pos_idx.shape[:-1] == pos_idx_mask.shape
                and pos_idx.shape[-1] == 2 and pos_idx.ndim
                == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape)
        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

        shp = pos_idx_mask.shape + (self.dim // 2, )  # ..., head_dim/2
        freqs_cis = torch.ones(shp, dtype=torch.complex64,
                               device=self.device)  # ..., head_dim/2
        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]]
        return freqs_cis
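
# Usage sketch (mirrors MoonVitEncoder.forward below): instantiate once, fetch
# the freqs for the current grid sizes, then rotate q/k right before attention:
#   rope_2d = Rope2DPosEmb(dim=head_dim, max_height=512, max_width=512)
#   freqs_cis = rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
#   xq, xk = apply_rope(xq, xk, freqs_cis)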


class MLP2(nn.Module):
    """
    Args:
        dims: [in_dim, hidden_dim, out_dim]
        bias: whether to use bias in linear layer.
    """

    def __init__(self,
                 dims: list[int],
                 activation,
                 bias=True,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        assert len(dims) == 3
        self.use_data_parallel = use_data_parallel
        self.fc0 = ReplicatedLinear(dims[0],
                                    dims[1],
                                    bias=bias,
                                    prefix=maybe_prefix(prefix, "fc0"))
        self.fc1 = ReplicatedLinear(dims[1],
                                    dims[2],
                                    bias=bias,
                                    prefix=maybe_prefix(prefix, "fc1"))
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc0(x)
        x = self.activation(x)
        x, _ = self.fc1(x)
        return x


class MoonVitEncoderLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        prefix: str = "",
        use_data_parallel: bool = False,
        *,
        attn_implementation: str = "sdpa",
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        # use fa2 in vllm by default
        if is_flash_attn_2_available():
            self.attn_implementation = "flash_attention_2"

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.use_data_parallel = use_data_parallel
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim],
                        activation,
                        prefix=f"{prefix}.mlp",
                        use_data_parallel=use_data_parallel)
        self.wqkv = ReplicatedLinear(hidden_dim,
                                     hidden_dim * 3,
                                     bias=attn_bias,
                                     prefix=f"{prefix}.wqkv")
        self.wo = ReplicatedLinear(hidden_dim,
                                   hidden_dim,
                                   bias=attn_bias,
                                   prefix=f"{prefix}.wo")

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
            cu_seqlens (torch.Tensor): cumulative sequence lengths of the packed input
        """
        xqkv, _ = self.wqkv(x)
        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens)
        attn_out, _ = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: non-packed (B, N, D) or packed (L, D). If non-packed, seqlens should be None;
                if packed, seqlens should be set.
        Returns:
            output: same shape as the input; non-packed (B, N, D) for non-packed input, (L, D) for packed input.
        """
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        attn_out = self.attention_qkvpacked(hidden_states,
                                            cu_seqlens,
                                            rope_freqs_cis=rope_freqs_cis)
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.mlp(self.norm1(hidden_states))
        hidden_states = residual + hidden_states
        return hidden_states


class MoonVitEncoder(nn.Module):

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
        self.blocks = nn.ModuleList([
            MoonVitEncoderLayer(use_data_parallel=use_data_parallel,
                                prefix=f"{prefix}.blocks.{layer_idx}",
                                **block_cfg)
            for layer_idx in range(num_layers)
        ])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
            grid_hws=grid_hw)
        lengths = torch.cat(
            (torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
             (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device)))
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

        for _, block in enumerate(self.blocks):
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


def patch_merger(
    x: torch.Tensor,
    grid_hw: torch.Tensor,
    merge_kernel_size: list[int, int] = (2, 2),
) -> list[torch.Tensor]:
    d_model = x.size(-1)

    outputs = []
    pre_sum = 0
    for x_shape in grid_hw.tolist():
        height, width = x_shape[0], x_shape[1]
        # Get the current sequence
        seq = x[pre_sum:pre_sum + height * width]
        # Reshape along self.merge_kernel_size and concat to the last dimension
        kernel_height, kernel_width = merge_kernel_size
        new_height, new_width = height // kernel_height, width // kernel_width
        reshaped_seq = seq.view(new_height, kernel_height, new_width,
                                kernel_width, d_model)
        reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
        padded_seq = reshaped_seq.view(new_height * new_width,
                                       kernel_height * kernel_width, -1)
        outputs.append(padded_seq)
        pre_sum += height * width

    return outputs
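
# Example (illustrative only): with merge_kernel_size=(2, 2) and a single image
# whose grid_hw is (4, 6), the 24 patch tokens are regrouped into 6 merged
# positions of 4 neighboring patches each, i.e. one (6, 4, d_model) tensor in
# the returned list.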


class MoonVitVLProjector(nn.Module):

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: list[int, int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[
            0] * merge_kernel_size[1]
        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states).view(
            -1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class MoonVitPretrainedModel(PreTrainedModel):
    config_class = MoonViTConfig
    model_type = "moonvit"
    _no_split_modules = ["PackingTransformer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self,
                 config: MoonViTConfig,
                 use_data_parallel: bool = False,
                 prefix: str = "",
                 *inputs,
                 **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.use_data_parallel = use_data_parallel
        self.merge_kernel_size = config.merge_kernel_size
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.vit_processing_type = "rope_2d"
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )

        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": PytorchGELUTanh(),
                "attn_bias": True,
                "attn_implementation": config._attn_implementation,
            },
            prefix=f"{prefix}.encoder",
        )

    def forward(self, pixel_values: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_hw (torch.Tensor): The grid height and width.

        Returns:
            torch.Tensor: The output tokens.
        """
        hidden_states = self.patch_embed(pixel_values, grid_hw)
        hidden_states = self.encoder(hidden_states, grid_hw)
        hidden_states = patch_merger(hidden_states,
                                     grid_hw,
                                     merge_kernel_size=self.merge_kernel_size)
        return hidden_states