[Model] Clean up MiniCPMV (#10751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -22,7 +22,7 @@
|
||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from functools import cached_property, partial
|
||||
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
|
||||
Set, Tuple, TypedDict, Union)
|
||||
|
||||
@@ -37,19 +37,15 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
|
||||
get_2d_sincos_pos_embed)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.models.minicpm import MiniCPMModel
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.models.utils import LLMWrapper
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
@@ -58,11 +54,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import is_pp_missing_parameter, maybe_prefix
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"llm.lm_head": "lm_head",
|
||||
}
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
RawImageType = Union[Image.Image, torch.Tensor]
|
||||
|
||||
@@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
|
||||
def get_placeholder(image_size: Tuple[int, int], num_image: int):
|
||||
if version == (2, 0) or version == (2, 5):
|
||||
return image_processor. \
|
||||
get_slice_image_placeholder(image_size)
|
||||
return image_processor. \
|
||||
get_slice_image_placeholder(image_size, num_image)
|
||||
return image_processor.get_slice_image_placeholder(image_size)
|
||||
return image_processor.get_slice_image_placeholder(
|
||||
image_size, num_image)
|
||||
|
||||
prompt = inputs.get("prompt")
|
||||
token_ids = inputs.get("prompt_token_ids")
|
||||
@@ -400,37 +391,32 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.vpm = self.init_vision_module(config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vpm"))
|
||||
param_dtype = torch.get_default_dtype()
|
||||
self.vpm.to(dtype=param_dtype)
|
||||
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
|
||||
self.vpm.embeddings.embed_dim)
|
||||
self.embed_dim = self.config.hidden_size
|
||||
|
||||
self.resampler = self.init_resampler(self.embed_dim,
|
||||
self.vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "resampler"))
|
||||
self.resampler.to(device="cuda", dtype=param_dtype)
|
||||
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "llm.lm_head"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.llm.make_empty_intermediate_tensors)
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
if hasattr(self.llm, "sampler"):
|
||||
return self.llm.sampler
|
||||
|
||||
return get_sampler()
|
||||
|
||||
def get_embedding(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_inputs: Optional[MiniCPMVImageInputs],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
|
||||
if hasattr(self.config, "scale_emb"):
|
||||
vlm_embedding *= self.config.scale_emb
|
||||
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||
|
||||
if image_inputs is None: # No image
|
||||
vision_hidden_states = torch.tensor([], device=input_ids.device)
|
||||
@@ -575,7 +561,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
|
||||
output = self.llm(
|
||||
output = self.llm.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
@@ -590,9 +576,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
return self.llm.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@@ -604,52 +588,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
use_default_weight_loading = False
|
||||
if self.is_default_weight_loading(name):
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
@@ -693,9 +633,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
data: MiniCPMVImageInputs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_default_weight_loading(self, name: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
|
||||
@@ -708,8 +645,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
|
||||
name="model")
|
||||
return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def init_vision_module(
|
||||
self,
|
||||
@@ -717,11 +653,12 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
# TODO :refactor this vision model
|
||||
# TODO: refactor this vision model
|
||||
try:
|
||||
import timm
|
||||
except ImportError:
|
||||
raise ImportError("Please install timm==0.9.10") from ImportError
|
||||
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
model = timm.create_model(
|
||||
"vit_so400m_patch14_siglip_384.webli",
|
||||
@@ -731,6 +668,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
dynamic_img_pad=True,
|
||||
)
|
||||
|
||||
model = model.to(dtype=torch.get_default_dtype())
|
||||
|
||||
if (isinstance(model, timm.models.VisionTransformer)
|
||||
and model.attn_pool is not None):
|
||||
model.attn_pool = torch.nn.Identity()
|
||||
@@ -759,7 +698,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
return resampler
|
||||
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
|
||||
|
||||
def get_vision_embedding(
|
||||
self,
|
||||
@@ -790,9 +729,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
|
||||
return self.get_vision_embedding(pixel_values)
|
||||
|
||||
def is_default_weight_loading(self, name: str) -> bool:
|
||||
return "resampler" in name or "vpm" in name
|
||||
|
||||
|
||||
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
@@ -843,8 +779,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
|
||||
name="model")
|
||||
return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def init_vision_module(
|
||||
self,
|
||||
@@ -871,7 +806,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
kv_dim=vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
return resampler
|
||||
|
||||
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
|
||||
|
||||
def get_vision_embedding(
|
||||
self,
|
||||
@@ -913,9 +849,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
return self.get_vision_embedding(all_pixel_values.type(dtype),
|
||||
patch_attn_mask, tgt_sizes)
|
||||
|
||||
def is_default_weight_loading(self, name: str) -> bool:
|
||||
return "resampler" in name
|
||||
|
||||
|
||||
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
@@ -966,8 +899,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
|
||||
name="model")
|
||||
return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def init_vision_module(
|
||||
self,
|
||||
@@ -995,7 +927,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
kv_dim=vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
return resampler
|
||||
|
||||
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
|
||||
|
||||
def get_vision_embedding(
|
||||
self,
|
||||
@@ -1043,9 +976,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
|
||||
return self.resampler(vision_embedding, tgt_sizes)
|
||||
|
||||
def is_default_weight_loading(self, name: str) -> bool:
|
||||
return "resampler" in name
|
||||
|
||||
|
||||
_SUPPORT_VERSION = {
|
||||
(2, 0): MiniCPMV2_0,
|
||||
|
||||
Reference in New Issue
Block a user