[Bugfix] Clean up MiniCPM-V (#6939)
Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -20,32 +20,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights."""
+"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
 import math
 import re
 from functools import partial
-from typing import Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn.functional as F
+import torch.types
 from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.idefics2.modeling_idefics2 import (
-    Idefics2VisionTransformer)

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsVision
-from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
+from vllm.model_executor.models.llama import LlamaModel
+from vllm.model_executor.models.minicpm import MiniCPMModel
+from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import (cached_get_image_processor,
@@ -53,12 +55,12 @@ from vllm.multimodal.image import (cached_get_image_processor,
 from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData

 _KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
+    "llm.lm_head": "lm_head",
+    "llm.model": "llm",
 }


-def get_abs_pos(abs_pos, tgt_size):
+def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
     # abs_pos: L, C
     # tgt_size: (H, W)
     # return: M, C
@@ -75,10 +77,10 @@ def get_abs_pos(abs_pos, tgt_size):


 # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
-def get_2d_sincos_pos_embed(embed_dim,
-                            grid_size,
-                            cls_token=False,
-                            version=2.0):
+def get_2d_sincos_pos_embed(embed_dim: int,
+                            grid_size: Union[int, Tuple[int, int]],
+                            cls_token: bool = False,
+                            version: Tuple[int, int] = (2, 0)):
     """
     grid_size: int of the grid height and width
     return:
@@ -95,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim,
     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
     grid = np.stack(grid, axis=0)

-    if version == 2.0:
+    if version == (2, 0):
         grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
     if cls_token:
@@ -106,7 +108,9 @@ def get_2d_sincos_pos_embed(embed_dim,
     return pos_embed


-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
+                                      grid: Union[int, Tuple[int, int]],
+                                      version: Tuple[int, int] = (2, 0)):
     assert embed_dim % 2 == 0

     # use half of dimensions to encode grid_h
@@ -115,14 +119,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
     emb_w = get_1d_sincos_pos_embed_from_grid(
         embed_dim // 2, grid[1], version)  # (H*W, D/2) or (H, W, D/2)

-    if version == 2.0:
+    if version == (2, 0):
         emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
     else:
         emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
     return emb


-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
+                                      pos: int,
+                                      version: Tuple[int, int] = (2, 0)):
     """
     embed_dim: output dimension for each position
     pos: a list of positions to be encoded: size (M,) / (H, W)
@@ -133,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
     omega /= embed_dim / 2.
     omega = 1. / 10000**omega  # (D/2,)

-    if version == 2.0:
+    if version == (2, 0):
         pos = pos.reshape(-1)  # (M,)
     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
     emb_sin = np.sin(out)  # (M, D/2)
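The version switch above only changes how the position grid is flattened; the underlying encoding is the standard 1D sine-cosine scheme from the MAE repository linked above. A minimal self-contained sketch of the (2, 0) path, with made-up sizes for illustration:

import numpy as np

def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    # Half of the channels carry sin, half carry cos, at log-spaced frequencies.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    out = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

emb = sincos_1d(64, np.arange(16))  # 16 positions, 64 channels
assert emb.shape == (16, 64)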
@@ -158,19 +164,19 @@ class Resampler(nn.Module):
     default_norm_layer = partial(nn.LayerNorm, eps=1e-6)

     def __init__(self,
-                 num_queries,
-                 grid_size,
-                 embed_dim,
-                 num_heads,
-                 kv_dim=None,
-                 norm_layer=default_norm_layer,
-                 adaptive=False,
-                 max_size=(70, 70),
-                 version=2.0):
+                 num_queries: int,
+                 grid_size: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: nn.Module = default_norm_layer,
+                 adaptive: bool = False,
+                 max_size: Tuple[int, int] = (70, 70),
+                 version: Tuple[int, int] = (2, 0)):
         super().__init__()

         self.version = version
-        if self.version == 2.0:
+        if self.version == (2, 0):
             self.num_queries = grid_size**2
         else:
             self.num_queries = num_queries
@@ -195,7 +201,7 @@ class Resampler(nn.Module):
         self.proj = nn.Parameter(
             (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

-        if self.version == 2.0:
+        if self.version == (2, 0):
             self.pos_embed = nn.Parameter(
                 torch.from_numpy(
                     get_2d_sincos_pos_embed(
@@ -206,14 +212,17 @@ class Resampler(nn.Module):

         self.apply(self._init_weights)

-    def _set_2d_pos_cache(self, max_size, device='cpu'):
+    def _set_2d_pos_cache(self,
+                          max_size: Tuple[int, int],
+                          device: torch.types.Device = 'cpu'):
         pos_embed = torch.from_numpy(
             get_2d_sincos_pos_embed(self.embed_dim,
                                     max_size,
                                     version=self.version)).float().to(device)
         self.register_buffer("pos_embed", pos_embed, persistent=False)

-    def _adjust_pos_cache(self, tgt_sizes, device):
+    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
+                          device: torch.types.Device):
         max_h = torch.max(tgt_sizes[:, 0])
         max_w = torch.max(tgt_sizes[:, 1])
         if max_h > self.max_size[0] or max_w > self.max_size[1]:
@@ -223,7 +232,7 @@ class Resampler(nn.Module):
         ]
         self._set_2d_pos_cache(self.max_size, device)

-    def _init_weights(self, m):
+    def _init_weights(self, m: nn.Module):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
@@ -232,7 +241,9 @@ class Resampler(nn.Module):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

-    def forward_2_5(self, x, tgt_sizes=None):
+    def forward_2_5(self,
+                    x: torch.Tensor,
+                    tgt_sizes: Optional[torch.Tensor] = None):
         assert x.shape[0] == tgt_sizes.shape[0]
         bs = x.shape[0]

@@ -278,7 +289,10 @@ class Resampler(nn.Module):
         x = x @ self.proj
         return x

-    def forward_2(self, x, tgt_sizes=None, attn_mask=None):
+    def forward_2(self,
+                  x: torch.Tensor,
+                  tgt_sizes: Optional[torch.Tensor] = None,
+                  attn_mask: Optional[torch.Tensor] = None):
         if self.adaptive:
             pos_embed = torch.Tensor(
                 get_2d_sincos_pos_embed(self.embed_dim,
@@ -302,8 +316,11 @@ class Resampler(nn.Module):
         x = x @ self.proj
         return x

-    def forward(self, x, tgt_sizes=None, attn_mask=None):
-        if self.version == 2.0:
+    def forward(self,
+                x: torch.Tensor,
+                tgt_sizes: Optional[torch.Tensor] = None,
+                attn_mask: Optional[torch.Tensor] = None):
+        if self.version == (2, 0):
             return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
         else:
             return self.forward_2_5(x, tgt_sizes=tgt_sizes)
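Replacing the float version with a tuple makes this dispatch exact and future-proof: tuples compare lexicographically, while float version numbers break down once a minor version has two digits. A small sketch (the "2.10" version is hypothetical, used only to show the failure mode):

# Tuples order lexicographically, so comparisons stay correct:
assert (2, 5) > (2, 0)
# A hypothetical "2.10" parsed as a float collapses to 2.1 and sorts
# *before* 2.5, while the tuple form sorts after it:
assert float("2.10") < 2.5
assert (2, 10) > (2, 5)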
@@ -322,7 +339,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int):
     return SequenceData(token_ids)


-def dummy_image_for_minicpmv(hf_config):
+def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
     width = height = hf_config.image_size
     image = Image.new("RGB", (width, height), color=0)
     return {"image": image}
@@ -381,7 +398,7 @@ class MiniCPMV(nn.Module, SupportsVision):

     def __init__(
         self,
-        config,
+        config: PretrainedConfig,
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -390,30 +407,48 @@ class MiniCPMV(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config

-        self.version = float(self.config.version)
+        if not hasattr(self.config, "version"):
+            if self.config.hidden_size == 2304 and self.config.query_num == 64:
+                self.version = (2, 0)
+            else:
+                self.version = (2, 5)
+        else:
+            self.version = str(self.config.version).split(".")
+            self.version = tuple([int(x) for x in self.version])
         self.llm = self.init_llm(config, cache_config, quant_config)
         self.vpm = self.init_vision_module()
         param_dtype = torch.get_default_dtype()
         self.vpm.to(dtype=param_dtype)
-        self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \
+        self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \
             else self.vpm.embeddings.embed_dim
-        self.embed_dim = self.llm.config.hidden_size
+        self.embed_dim = self.config.hidden_size
         self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
         self.resampler.to(device="cuda", dtype=param_dtype)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()

-    def init_llm(self, config, cache_config, quant_config):
-        if self.version == 2.0:
-            return MiniCPMForCausalLM(config,
-                                      cache_config=cache_config,
-                                      quant_config=quant_config)
+    def init_llm(self,
+                 config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        if self.version == (2, 0):
+            return MiniCPMModel(config,
+                                cache_config=cache_config,
+                                quant_config=quant_config)
+        elif self.version == (2, 5):
+            return LlamaModel(config,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
         else:
-            return LlamaForCausalLM(config,
-                                    cache_config=cache_config,
-                                    quant_config=quant_config)
+            return Qwen2Model(config,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def init_vision_module(self):
-        if self.version == 2.0:
+        if self.version == (2, 0):
             try:
                 import timm
             except ImportError:
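A minimal sketch of the version detection introduced above, using a stand-in object for the HuggingFace PretrainedConfig (the hidden_size/query_num fallback covers older checkpoints whose configs lack a version field):

class FakeConfig:  # stand-in for PretrainedConfig, for illustration only
    version = "2.5"

config = FakeConfig()
version = tuple(int(x) for x in str(config.version).split("."))
assert version == (2, 5)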
@@ -433,16 +468,30 @@ class MiniCPMV(nn.Module, SupportsVision):

             if self.config.drop_vision_last_layer:
                 model.blocks = model.blocks[:-1]
-        else:
+        elif self.version == (2, 5):
+            from transformers.models.idefics2.modeling_idefics2 import (
+                Idefics2VisionTransformer)
             model = Idefics2VisionTransformer(self.config.vision_config)
             if self.config.drop_vision_last_layer:
                 model.encoder.layers = model.encoder.layers[:-1]
+        else:
+            from vllm.model_executor.models.na_vit import (
+                SiglipVisionTransformer)
+            if self.config._attn_implementation == 'flash_attention_2':
+                self.config.vision_config._attn_implementation \
+                    = 'flash_attention_2'
+            else:
+                # not support sdpa
+                self.config.vision_config._attn_implementation = 'eager'
+            model = SiglipVisionTransformer(self.config.vision_config)
+            if self.config.drop_vision_last_layer:
+                model.encoder.layers = model.encoder.layers[:-1]
         return model

-    def init_resampler(self, embed_dim, vision_dim):
+    def init_resampler(self, embed_dim: int, vision_dim: int):
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(torch.float16)
-        if self.version == 2.0:
+        if self.version == (2, 0):
             resampler = Resampler(grid_size=int(
                 math.sqrt(self.config.query_num)),
                                   num_queries=None,
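For the (2, 0) branch above, the resampler grid is the square root of the configured query count; with query_num = 64 (the value the version-detection heuristic in __init__ checks for), that yields an 8x8 grid. A one-line check:

import math

query_num = 64  # per the config check in __init__ above
grid_size = int(math.sqrt(query_num))
assert grid_size == 8 and grid_size**2 == query_num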
@@ -463,11 +512,11 @@ class MiniCPMV(nn.Module, SupportsVision):
         return resampler

     def get_vision_embedding(self,
-                             pixel_values,
-                             patch_attn_mask=None,
-                             tgt_sizes=None,
-                             version=2.0):
-        if version == 2.0:
+                             pixel_values: List[List[torch.Tensor]],
+                             patch_attn_mask: Optional[torch.Tensor] = None,
+                             tgt_sizes: Optional[torch.Tensor] = None,
+                             version: Tuple[int, int] = (2, 0)):
+        if version == (2, 0):
             res = []
             dtype = self.vpm.pos_embed.data.dtype
             for pixel_value in pixel_values:
@@ -484,21 +533,32 @@ class MiniCPMV(nn.Module, SupportsVision):
                     num_prefix_tokens:]
                 res.append(self.resampler(vision_embedding, tgt_size))
             return torch.vstack(res)
-        else:
+        elif version == (2, 5):
             vision_embedding = self.vpm(
                 pixel_values.type(dtype),
                 patch_attention_mask=patch_attn_mask).last_hidden_state
             vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+        else:
+            vision_embedding = self.vpm(pixel_values.type(dtype),
+                                        patch_attention_mask=patch_attn_mask,
+                                        tgt_sizes=tgt_sizes).last_hidden_state

-    def get_image_bounds(self, input_ids):
+    def get_image_bounds(self, input_ids: torch.Tensor):
         tokenizer = cached_get_tokenizer(self.config._name_or_path,
                                          trust_remote_code=True)
-        im_start_token_id = tokenizer.im_start_id
-        im_end_token_id = tokenizer.im_end_id
-        image_start_tokens = torch.where(input_ids == im_start_token_id)[0]
+        if not hasattr(tokenizer, "slice_start_id"):
+            start_cond = input_ids == tokenizer.im_start_id
+            end_cond = input_ids == tokenizer.im_end_id
+        else:
+            start_cond = (input_ids == tokenizer.im_start_id) | (
+                input_ids == tokenizer.slice_start_id)
+            end_cond = (input_ids == tokenizer.im_end_id) | (
+                input_ids == tokenizer.slice_end_id)
+
+        image_start_tokens = torch.where(start_cond)[0]
         image_start_tokens += 1
-        image_end_tokens = torch.where(input_ids == im_end_token_id)[0]
-        valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
+        image_end_tokens = torch.where(end_cond)[0]
+        valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+        if valid_image_nums == 0:
+            return []
         image_bound = torch.hstack([
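A toy run of the bounds logic above, with token ids 5 and 6 standing in for the real <im_start>/<im_end> ids that come from the tokenizer:

import torch

input_ids = torch.tensor([1, 5, 9, 9, 6, 2, 5, 9, 6])
starts = torch.where(input_ids == 5)[0] + 1  # first token inside each image
ends = torch.where(input_ids == 6)[0]        # position of each <im_end>
bounds = torch.hstack([starts.unsqueeze(-1), ends.unsqueeze(-1)])
assert bounds.tolist() == [[2, 4], [7, 8]]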
@@ -508,12 +568,14 @@ class MiniCPMV(nn.Module, SupportsVision):

         return image_bound

-    def get_vision_hidden_states(self, data):
+    def get_vision_hidden_states(self, data: Dict[str,
+                                                  Union[List[torch.Tensor],
+                                                        torch.Tensor]]):
         if "vision_hidden_states" not in data:
             pixel_values = data["pixel_values"]
             tgt_sizes = data["tgt_sizes"]
             vision_hidden_states = []
-            if self.version == 2.0:
+            if self.version == (2, 0):
                 if pixel_values is not None and len(pixel_values) > 0:
                     vision_hidden_states = self.get_vision_embedding(
                         pixel_values)
@@ -534,17 +596,26 @@ class MiniCPMV(nn.Module, SupportsVision):
                     B, L, _ = all_pixel_values.shape
                     all_pixel_values = all_pixel_values.permute(
                         0, 2, 1).reshape(B, 3, -1, L)

                     patch_attn_mask = torch.zeros((B, 1, max_patches),
                                                   dtype=torch.bool,
                                                   device=device)
-                    for i in range(B):
-                        patch_attn_mask[i, :tgt_sizes[i][0] *
-                                        tgt_sizes[i][1]] = True
-
-                    vision_embedding = self.vpm(
-                        all_pixel_values.type(dtype),
-                        patch_attention_mask=patch_attn_mask).last_hidden_state
+                    if self.version == (2, 5):
+                        for i in range(B):
+                            patch_attn_mask[i, :tgt_sizes[i][0] *
+                                            tgt_sizes[i][1]] = True
+                        vision_embedding = self.vpm(
+                            all_pixel_values.type(dtype),
+                            patch_attention_mask=patch_attn_mask
+                        ).last_hidden_state
+                    else:
+                        for i in range(B):
+                            patch_attn_mask[i, 0, :tgt_sizes[i][0] *
+                                            tgt_sizes[i][1]] = True
+                        vision_embedding = self.vpm(
+                            all_pixel_values.type(dtype),
+                            patch_attention_mask=patch_attn_mask,
+                            tgt_sizes=tgt_sizes).last_hidden_state
                     vision_hidden_states = self.resampler(
                         vision_embedding, tgt_sizes)

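A sketch of the mask construction above with toy sizes: each image i contributes tgt_h * tgt_w valid patches and the rest of its row stays False (note the extra dim-1 index in the non-2.5 branch):

import torch

B, max_patches = 2, 12
tgt_sizes = torch.tensor([[2, 3], [3, 4]])  # (H, W) per image, made up
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
assert patch_attn_mask.sum().item() == 2 * 3 + 3 * 4  # 6 + 12 valid patches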
@@ -556,7 +627,8 @@ class MiniCPMV(nn.Module, SupportsVision):

         return vision_hidden_states

-    def get_embedding(self, data):
+    def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
+                                                  torch.Tensor]]):
         input_ids = data["input_ids"]

         vision_hidden_states = self.get_vision_hidden_states(data)
@@ -565,11 +637,11 @@ class MiniCPMV(nn.Module, SupportsVision):
         else:
             image_bounds = []

-        if hasattr(self.llm.config, 'scale_emb'):
-            vlm_embedding = self.llm.model.embed_tokens(
-                input_ids) * self.llm.config.scale_emb
+        if hasattr(self.config, 'scale_emb'):
+            vlm_embedding = self.llm.embed_tokens(
+                input_ids) * self.config.scale_emb
         else:
-            vlm_embedding = self.llm.model.embed_tokens(input_ids)
+            vlm_embedding = self.llm.embed_tokens(input_ids)
         vision_hidden_states = [
             i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
             for i in vision_hidden_states
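The hunk above reads scale_emb from the top-level config and applies it to the raw token embeddings; since the LLM is now a bare *Model rather than a *ForCausalLM wrapper, embed_tokens hangs directly off self.llm. A sketch with made-up dimensions:

import torch
import torch.nn as nn

embed_tokens = nn.Embedding(10, 4)  # vocab 10, hidden 4, illustrative only
scale_emb = 12.0                    # made-up value for the config field
input_ids = torch.tensor([1, 2, 3])
vlm_embedding = embed_tokens(input_ids) * scale_emb
assert vlm_embedding.shape == (3, 4)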
@@ -587,7 +659,9 @@ class MiniCPMV(nn.Module, SupportsVision):
             vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
         return vlm_embedding, vision_hidden_states

-    def process_multimodal_inputs(self, inputs):
+    def process_multimodal_inputs(self, inputs: Dict[str,
+                                                     Union[List[torch.Tensor],
+                                                           torch.Tensor]]):
         pixel_values = []
         tgt_sizes = []
         for b in range(len(inputs["pixel_values"])):
@@ -613,7 +687,6 @@ class MiniCPMV(nn.Module, SupportsVision):
             "input_ids": input_ids,
             "tgt_sizes": kwargs.pop("tgt_sizes", None),
         }
-
         inputs = self.process_multimodal_inputs(inputs)

         vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
@@ -623,19 +696,21 @@ class MiniCPMV(nn.Module, SupportsVision):
                           kv_caches=kv_caches,
                           attn_metadata=attn_metadata,
                           intermediate_tensors=intermediate_tensors,
-                          input_embeds=vlm_embeddings)
+                          inputs_embeds=vlm_embeddings)
         return output

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        return self.llm.compute_logits(hidden_states, sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits

     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.llm.sample(logits, sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -649,9 +724,9 @@ class MiniCPMV(nn.Module, SupportsVision):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            # for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-            #     if key_to_modify in name:
-            #         name = name.replace(key_to_modify, new_key)
+            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+                if key_to_modify in name:
+                    name = name.replace(key_to_modify, new_key)
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
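Re-enabling the remap loop makes checkpoint names line up with the new module layout (self.llm plus a separate self.lm_head). A toy run on a made-up checkpoint key:

_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
    "llm.model": "llm",
}

name = "llm.model.layers.0.self_attn.qkv_proj.weight"
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
    if key_to_modify in name:
        name = name.replace(key_to_modify, new_key)
assert name == "llm.layers.0.self_attn.qkv_proj.weight"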