[Bugfix] Clean up MiniCPM-V (#6939)

Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Alphi
2024-07-31 22:39:19 +08:00
committed by GitHub
parent 6512937de1
commit 2f4e108f75
6 changed files with 975 additions and 94 deletions


@@ -20,32 +20,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights."""
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig
from transformers.models.idefics2.modeling_idefics2 import (
Idefics2VisionTransformer)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
@@ -53,12 +55,12 @@ from vllm.multimodal.image import (cached_get_image_processor,
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
"llm.lm_head": "lm_head",
"llm.model": "llm",
}
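# A minimal sketch (not part of this commit) of how this mapping is applied
# during load_weights (see the loop near the end of this diff); the
# checkpoint key below is hypothetical.
name = "llm.model.layers.0.self_attn.qkv_proj.weight"  # hypothetical key
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
    if key_to_modify in name:
        name = name.replace(key_to_modify, new_key)
assert name == "llm.layers.0.self_attn.qkv_proj.weight"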
def get_abs_pos(abs_pos, tgt_size):
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
@@ -75,10 +77,10 @@ def get_abs_pos(abs_pos, tgt_size):
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim,
grid_size,
cls_token=False,
version=2.0):
def get_2d_sincos_pos_embed(embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0)):
"""
grid_size: int of the grid height and width
return:
@@ -95,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim,
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if version == 2.0:
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
@@ -106,7 +108,9 @@ def get_2d_sincos_pos_embed(embed_dim,
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
grid: Union[int, Tuple[int, int]],
version: Tuple[int, int] = (2, 0)):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
@@ -115,14 +119,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == 2.0:
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
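# A minimal sketch (not part of this commit) of what the version tuple
# changes here: (2, 0) flattens the grid into an (H*W, D) table, later
# versions keep the (H, W, D) layout. Assumes grid_size accepts a
# (height, width) tuple, per the new annotation above.
flat = get_2d_sincos_pos_embed(256, (12, 16), version=(2, 0))
assert flat.shape == (12 * 16, 256)
grid2d = get_2d_sincos_pos_embed(256, (12, 16), version=(2, 5))
assert grid2d.shape == (12, 16, 256)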
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
pos: int,
version: Tuple[int, int] = (2, 0)):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
@@ -133,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
if version == 2.0:
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
@@ -158,19 +164,19 @@ class Resampler(nn.Module):
default_norm_layer = partial(nn.LayerNorm, eps=1e-6)
def __init__(self,
num_queries,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=default_norm_layer,
adaptive=False,
max_size=(70, 70),
version=2.0):
num_queries: int,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: nn.Module = default_norm_layer,
adaptive: bool = False,
max_size: Tuple[int, int] = (70, 70),
version: Tuple[int, int] = (2, 0)):
super().__init__()
self.version = version
if self.version == 2.0:
if self.version == (2, 0):
self.num_queries = grid_size**2
else:
self.num_queries = num_queries
@@ -195,7 +201,7 @@ class Resampler(nn.Module):
self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
if self.version == 2.0:
if self.version == (2, 0):
self.pos_embed = nn.Parameter(
torch.from_numpy(
get_2d_sincos_pos_embed(
@@ -206,14 +212,17 @@ class Resampler(nn.Module):
self.apply(self._init_weights)
def _set_2d_pos_cache(self, max_size, device='cpu'):
def _set_2d_pos_cache(self,
max_size: Tuple[int, int],
device: torch.types.Device = 'cpu'):
pos_embed = torch.from_numpy(
get_2d_sincos_pos_embed(self.embed_dim,
max_size,
version=self.version)).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes, device):
def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
device: torch.types.Device):
max_h = torch.max(tgt_sizes[:, 0])
max_w = torch.max(tgt_sizes[:, 1])
if max_h > self.max_size[0] or max_w > self.max_size[1]:
@@ -223,7 +232,7 @@ class Resampler(nn.Module):
]
self._set_2d_pos_cache(self.max_size, device)
def _init_weights(self, m):
def _init_weights(self, m: nn.Module):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
@@ -232,7 +241,9 @@ class Resampler(nn.Module):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_2_5(self, x, tgt_sizes=None):
def forward_2_5(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None):
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]
@@ -278,7 +289,10 @@ class Resampler(nn.Module):
x = x @ self.proj
return x
def forward_2(self, x, tgt_sizes=None, attn_mask=None):
def forward_2(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None):
if self.adaptive:
pos_embed = torch.Tensor(
get_2d_sincos_pos_embed(self.embed_dim,
@@ -302,8 +316,11 @@ class Resampler(nn.Module):
x = x @ self.proj
return x
def forward(self, x, tgt_sizes=None, attn_mask=None):
if self.version == 2.0:
def forward(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None):
if self.version == (2, 0):
return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
else:
return self.forward_2_5(x, tgt_sizes=tgt_sizes)
@@ -322,7 +339,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int):
return SequenceData(token_ids)
def dummy_image_for_minicpmv(hf_config):
def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
@@ -381,7 +398,7 @@ class MiniCPMV(nn.Module, SupportsVision):
def __init__(
self,
config,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@@ -390,30 +407,48 @@ class MiniCPMV(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
self.version = float(self.config.version)
if not hasattr(self.config, "version"):
if self.config.hidden_size == 2304 and self.config.query_num == 64:
self.version = (2, 0)
else:
self.version = (2, 5)
else:
self.version = str(self.config.version).split(".")
self.version = tuple([int(x) for x in self.version])
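# A minimal sketch (not part of this commit) of the string-to-tuple parsing
# above, so "2.5" from the HF config compares cleanly against (2, 5):
version = tuple(int(x) for x in str("2.5").split("."))
assert version == (2, 5)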
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \
self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \
else self.vpm.embeddings.embed_dim
self.embed_dim = self.llm.config.hidden_size
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler.to(device="cuda", dtype=param_dtype)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def init_llm(self, config, cache_config, quant_config):
if self.version == 2.0:
return MiniCPMForCausalLM(config,
cache_config=cache_config,
quant_config=quant_config)
def init_llm(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
if self.version == (2, 0):
return MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config)
elif self.version == (2, 5):
return LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config)
else:
return LlamaForCausalLM(config,
cache_config=cache_config,
quant_config=quant_config)
return Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config)
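# A minimal sketch (not part of this commit) of the dispatch this method now
# implements; that the Qwen2 fallback targets the newer Qwen2-based
# MiniCPM-V release is an assumption based on the else branch.
BACKBONE_BY_VERSION = {(2, 0): "MiniCPMModel", (2, 5): "LlamaModel"}

def backbone_for(version):
    # anything newer falls through to the Qwen2 backbone
    return BACKBONE_BY_VERSION.get(version, "Qwen2Model")

assert backbone_for((2, 6)) == "Qwen2Model"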
def init_vision_module(self):
if self.version == 2.0:
if self.version == (2, 0):
try:
import timm
except ImportError:
@@ -433,16 +468,30 @@ class MiniCPMV(nn.Module, SupportsVision):
if self.config.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
else:
elif self.version == (2, 5):
from transformers.models.idefics2.modeling_idefics2 import (
Idefics2VisionTransformer)
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
else:
from vllm.model_executor.models.na_vit import (
SiglipVisionTransformer)
if self.config._attn_implementation == 'flash_attention_2':
self.config.vision_config._attn_implementation \
= 'flash_attention_2'
else:
# SDPA is not supported
self.config.vision_config._attn_implementation = 'eager'
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self, embed_dim, vision_dim):
def init_resampler(self, embed_dim: int, vision_dim: int):
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
if self.version == 2.0:
if self.version == (2, 0):
resampler = Resampler(grid_size=int(
math.sqrt(self.config.query_num)),
num_queries=None,
@@ -463,11 +512,11 @@ class MiniCPMV(nn.Module, SupportsVision):
return resampler
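# A minimal sketch (not part of this commit) of the save/restore pattern used
# above: the resampler is created in float16 without disturbing the caller's
# default dtype.
import torch

saved = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
w = torch.empty(2, 2)           # allocated in float16
torch.set_default_dtype(saved)  # restore before the rest of model init
assert w.dtype == torch.float16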
def get_vision_embedding(self,
pixel_values,
patch_attn_mask=None,
tgt_sizes=None,
version=2.0):
if version == 2.0:
pixel_values: List[List[torch.Tensor]],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
version: Tuple[int, int] = (2, 0)):
if version == (2, 0):
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
@@ -484,21 +533,32 @@ class MiniCPMV(nn.Module, SupportsVision):
num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
else:
elif version == (2, 5):
vision_embedding = self.vpm(
pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
else:
vision_embedding = self.vpm(pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
def get_image_bounds(self, input_ids):
def get_image_bounds(self, input_ids: torch.Tensor):
tokenizer = cached_get_tokenizer(self.config._name_or_path,
trust_remote_code=True)
im_start_token_id = tokenizer.im_start_id
im_end_token_id = tokenizer.im_end_id
image_start_tokens = torch.where(input_ids == im_start_token_id)[0]
if not hasattr(tokenizer, "slice_start_id"):
start_cond = input_ids == tokenizer.im_start_id
end_cond = input_ids == tokenizer.im_end_id
else:
start_cond = (input_ids == tokenizer.im_start_id) | (
input_ids == tokenizer.slice_start_id)
end_cond = (input_ids == tokenizer.im_end_id) | (
input_ids == tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(input_ids == im_end_token_id)[0]
valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return []
image_bound = torch.hstack([
@@ -508,12 +568,14 @@ class MiniCPMV(nn.Module, SupportsVision):
return image_bound
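# A minimal sketch (not part of this commit) of the pairing logic above:
# start delimiters shift one position right so each bound covers only the
# placeholder tokens. Token ids below are made up.
import torch

im_start, im_end = 101, 102            # hypothetical delimiter token ids
input_ids = torch.tensor([1, 101, 7, 7, 7, 102, 2])
starts = torch.where(input_ids == im_start)[0] + 1  # first token inside
ends = torch.where(input_ids == im_end)[0]
bounds = torch.hstack([starts.unsqueeze(-1), ends.unsqueeze(-1)])
assert bounds.tolist() == [[2, 5]]     # placeholders sit at positions 2..4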
def get_vision_hidden_states(self, data):
def get_vision_hidden_states(self, data: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
if "vision_hidden_states" not in data:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
vision_hidden_states = []
if self.version == 2.0:
if self.version == (2, 0):
if pixel_values is not None and len(pixel_values) > 0:
vision_hidden_states = self.get_vision_embedding(
pixel_values)
@@ -534,17 +596,26 @@ class MiniCPMV(nn.Module, SupportsVision):
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(
0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
if self.version == (2, 5):
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask
).last_hidden_state
else:
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
vision_hidden_states = self.resampler(
vision_embedding, tgt_sizes)
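# A minimal sketch (not part of this commit) of the boolean patch mask built
# above from per-image (height, width) patch grids; sizes are made up.
import torch

tgt_sizes = torch.tensor([[2, 3], [1, 4]])   # (h, w) patch grid per image
max_patches = int((tgt_sizes[:, 0] * tgt_sizes[:, 1]).max())
patch_attn_mask = torch.zeros((2, 1, max_patches), dtype=torch.bool)
for i in range(2):
    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
assert patch_attn_mask[1, 0].tolist() == [True] * 4 + [False] * 2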
@@ -556,7 +627,8 @@ class MiniCPMV(nn.Module, SupportsVision):
return vision_hidden_states
def get_embedding(self, data):
def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
torch.Tensor]]):
input_ids = data["input_ids"]
vision_hidden_states = self.get_vision_hidden_states(data)
@@ -565,11 +637,11 @@ class MiniCPMV(nn.Module, SupportsVision):
else:
image_bounds = []
if hasattr(self.llm.config, 'scale_emb'):
vlm_embedding = self.llm.model.embed_tokens(
input_ids) * self.llm.config.scale_emb
if hasattr(self.config, 'scale_emb'):
vlm_embedding = self.llm.embed_tokens(
input_ids) * self.config.scale_emb
else:
vlm_embedding = self.llm.model.embed_tokens(input_ids)
vlm_embedding = self.llm.embed_tokens(input_ids)
vision_hidden_states = [
i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
for i in vision_hidden_states
@@ -587,7 +659,9 @@ class MiniCPMV(nn.Module, SupportsVision):
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
return vlm_embedding, vision_hidden_states
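# A minimal sketch (not part of this commit) of the scale_emb path: MiniCPM
# scales token embeddings by a config constant before the vision features
# are spliced in. The value 12.0 is illustrative, not from this diff.
import torch

scale_emb = 12.0                        # illustrative; read from the HF config
embed_tokens = torch.nn.Embedding(8, 4)
vlm_embedding = embed_tokens(torch.tensor([1, 2, 3])) * scale_emb
assert vlm_embedding.shape == (3, 4)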
def process_multimodal_inputs(self, inputs):
def process_multimodal_inputs(self, inputs: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
pixel_values = []
tgt_sizes = []
for b in range(len(inputs["pixel_values"])):
@@ -613,7 +687,6 @@ class MiniCPMV(nn.Module, SupportsVision):
"input_ids": input_ids,
"tgt_sizes": kwargs.pop("tgt_sizes", None),
}
inputs = self.process_multimodal_inputs(inputs)
vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
@@ -623,19 +696,21 @@ class MiniCPMV(nn.Module, SupportsVision):
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
input_embeds=vlm_embeddings)
inputs_embeds=vlm_embeddings)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
return self.llm.compute_logits(hidden_states, sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.llm.sample(logits, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
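# A minimal sketch (not part of this commit) of the head path the wrapper now
# owns instead of delegating to the inner LM; plain Linear and greedy argmax
# are stand-ins for ParallelLMHead and Sampler.
import torch

hidden_states = torch.randn(2, 16)               # made-up hidden states
lm_head = torch.nn.Linear(16, 100, bias=False)   # stand-in for ParallelLMHead
logits = lm_head(hidden_states)
next_tokens = torch.argmax(logits, dim=-1)       # greedy stand-in for Sampler
assert next_tokens.shape == (2,)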
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -649,9 +724,9 @@ class MiniCPMV(nn.Module, SupportsVision):
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
# if key_to_modify in name:
# name = name.replace(key_to_modify, new_key)
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name