[Model] Clean up MiniCPMV (#10751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-29 12:47:06 +08:00
committed by GitHub
parent c83919c7a6
commit fa6ecb9aa7
7 changed files with 149 additions and 215 deletions

View File

@@ -22,7 +22,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, TypedDict, Union)
@@ -37,19 +37,15 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
@@ -58,11 +54,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter, maybe_prefix
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
}
from .utils import AutoWeightsLoader, maybe_prefix
RawImageType = Union[Image.Image, torch.Tensor]
@@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
def get_placeholder(image_size: Tuple[int, int], num_image: int):
if version == (2, 0) or version == (2, 5):
return image_processor. \
get_slice_image_placeholder(image_size)
return image_processor. \
get_slice_image_placeholder(image_size, num_image)
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size, num_image)
prompt = inputs.get("prompt")
token_ids = inputs.get("prompt_token_ids")
@@ -400,37 +391,32 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.vpm = self.init_vision_module(config,
quant_config,
prefix=maybe_prefix(prefix, "vpm"))
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "resampler"))
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "llm.lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
    """Lazily resolve the sampler, preferring the wrapped LLM's own.

    Computed once and cached (``cached_property``). If the inner
    ``self.llm`` module exposes a ``sampler`` attribute, reuse it so
    sampling behavior matches the underlying causal-LM; otherwise fall
    back to the default ``get_sampler()``.
    """
    if hasattr(self.llm, "sampler"):
        return self.llm.sampler
    return get_sampler()
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
vlm_embedding *= self.config.scale_emb
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
@@ -575,7 +561,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
# for `torch.compile` integration
input_ids = None
output = self.llm(
output = self.llm.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
@@ -590,9 +576,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
return self.llm.compute_logits(hidden_states, sampling_metadata)
def sample(
self,
@@ -604,52 +588,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
use_default_weight_loading = False
if self.is_default_weight_loading(name):
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""
@@ -693,9 +633,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool:
    """Return True if the checkpoint weight *name* should bypass the
    stacked-parameter mapping and use the default weight loader.

    Abstract hook: each MiniCPM-V version subclass must override this
    to identify its vision/resampler weights.
    """
    raise NotImplementedError
class MiniCPMV2_0(MiniCPMVBaseModel):
@@ -708,8 +645,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
name="model")
return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)
def init_vision_module(
self,
@@ -717,11 +653,12 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
# TODO :refactor this vision model
# TODO: refactor this vision model
try:
import timm
except ImportError:
raise ImportError("Please install timm==0.9.10") from ImportError
with set_default_torch_dtype(torch.float16):
model = timm.create_model(
"vit_so400m_patch14_siglip_384.webli",
@@ -731,6 +668,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
dynamic_img_pad=True,
)
model = model.to(dtype=torch.get_default_dtype())
if (isinstance(model, timm.models.VisionTransformer)
and model.attn_pool is not None):
model.attn_pool = torch.nn.Identity()
@@ -759,7 +698,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config,
prefix=prefix)
return resampler
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
@@ -790,9 +729,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return self.get_vision_embedding(pixel_values)
def is_default_weight_loading(self, name: str) -> bool:
    # Resampler and vision-tower ("vpm") weights have no stacked
    # q/k/v or gate/up projections, so they take the default loader.
    return "resampler" in name or "vpm" in name
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
@@ -843,8 +779,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
name="model")
return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
def init_vision_module(
self,
@@ -871,7 +806,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
@@ -913,9 +849,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return self.get_vision_embedding(all_pixel_values.type(dtype),
patch_attn_mask, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
    # Only resampler weights bypass the stacked-parameter mapping here;
    # the Idefics2 vision tower follows the standard loading path.
    return "resampler" in name
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
@@ -966,8 +899,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
name="model")
return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix)
def init_vision_module(
self,
@@ -995,7 +927,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
@@ -1043,9 +976,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
    # Same policy as MiniCPM-V 2.5: only resampler weights use the
    # default (non-stacked) weight loader.
    return "resampler" in name
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,