[model] Support MiniCPM-V 4.0 (#22166)

Co-authored-by: imning3 <hbning@pku.edu.cn>
This commit is contained in:
tc-mb
2025-08-07 09:35:46 +08:00
committed by GitHub
parent e8961e963a
commit 41b67f4263
3 changed files with 140 additions and 12 deletions

View File

@@ -38,6 +38,8 @@ from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -339,7 +341,9 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the supported modalities mapped to their item limits.

    ``"image"`` is always present. ``"video"`` is added only when
    ``self.get_model_version()`` equals ``(2, 6)`` or ``(4, 0)``.
    Every limit is ``None``.
    NOTE(review): ``None`` presumably means "no fixed limit" -- confirm
    against the base class contract.

    Returns:
        A mapping from modality name to its limit.
    """
    mm_limits: dict[str, Optional[int]] = {"image": None}
    # Only these model versions accept video input.
    if self.get_model_version() in ((2, 6), (4, 0)):
        mm_limits["video"] = None
    return mm_limits
@@ -620,7 +624,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys: set[str],
) -> dict[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together
if self.info.get_model_version() == (2, 6):
if self.info.get_model_version() == (
2, 6) or self.info.get_model_version() == (4, 0):
inputs = super()._call_hf_processor(
prompt=prompts, # type: ignore
mm_data=mm_data,
@@ -679,10 +684,18 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,
}
placeholders = [("image", self.info.image_pattern),
("video", self.info.video_pattern)]
# hard code for inconsistency of encode-decode image_pattern
additional_placeholders = []
tokenizer = self.info.get_tokenizer()
for modality, pattern in placeholders:
sub_pattern = tokenizer.decode(
tokenizer.encode(pattern, add_special_tokens=False))
if sub_pattern != pattern:
additional_placeholders.append((modality, sub_pattern))
placeholders += additional_placeholders
def get_image_replacement(item_idx: int):
images = mm_items.get_items(
@@ -714,9 +727,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
return [
PromptReplacement(modality=modality,
target=placeholder[modality],
target=pattern,
replacement=get_replacement[modality])
for modality in ("image", "video")
for modality, pattern in placeholders
]
def _get_mm_fields_config(
@@ -1262,11 +1275,124 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return self.resampler(vision_embedding, tgt_sizes)
def load_weights(
    self,
    weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
    """Load ``weights`` into this module through ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns for them.
    """
    # Handed to the loader as ``skip_prefixes``.
    skipped = ["apm.", "audio", "tts"]
    return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)
class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V model variant for version ``(4, 0)``.

    Builds a ``LlamaForCausalLM`` language model, an
    ``Idefics2VisionTransformer`` vision module and a ``Resampler2_5``
    resampler. AWQ / AWQ-Marlin quantization configs are not forwarded
    to the vision module or the resampler.
    """

    # Packed module name -> names of the sub-modules listed for it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialise via the base class and check the version is (4, 0)."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (4, 0)

    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
        """Return ``None`` for AWQ / AWQ-Marlin configs, else the config."""
        awq_variants = (AWQConfig, AWQMarlinConfig)
        if not isinstance(quant_config, awq_variants):
            return quant_config
        return None

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Create the language model as a ``LlamaForCausalLM``."""
        llm = LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
        return llm

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the vision module from ``config.vision_config``.

        The last encoder layer is removed when
        ``self.config.drop_vision_last_layer`` is truthy.
        """
        vision_quant = self._maybe_ignore_quant_config(quant_config)
        vpm = Idefics2VisionTransformer(
            config.vision_config,
            quant_config=vision_quant,
            prefix=prefix,
        )
        # The flag is read from self.config, not from the ``config`` argument.
        if self.config.drop_vision_last_layer:
            vpm.encoder.layers = vpm.encoder.layers[:-1]
        return vpm

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the resampler and move it to the platform device.

        The resampler is constructed while the default torch dtype is
        ``float16`` and is then converted to the default dtype.
        """
        resampler_quant = self._maybe_ignore_quant_config(quant_config)
        with set_default_torch_dtype(torch.float16):
            # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
                quant_config=resampler_quant,
                prefix=prefix,
            )

        # NOTE(review): the cast is placed after the float16 context so
        # that get_default_dtype() reports the outer default -- confirm
        # this placement against the upstream file.
        return resampler.to(device=current_platform.device_type,
                            dtype=torch.get_default_dtype())

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Encode a batch of images and pass the result to the resampler.

        Each entry of ``data["pixel_values"]`` is copied into a
        zero-filled ``(B, 3, P, L)`` tensor, where ``L`` is the largest
        last-dimension size in the batch. A boolean mask marks, per
        entry, the first ``prod(tgt_size)`` positions as valid.

        Returns:
            The output of ``self.resampler`` applied to the vision
            module's embedding and ``data["tgt_sizes"]``.
        """
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        batch_size = len(pixel_values)
        first_item = pixel_values[0]
        patch_dim = first_item.shape[-2]
        max_len = max(item.shape[-1] for item in pixel_values)
        device = first_item.device
        dtype = first_item.dtype

        # Pad each item on the right along the last dimension.
        all_pixel_values = torch.zeros(
            (batch_size, 3, patch_dim, max_len),
            dtype=dtype,
            device=device,
        )
        for row, item in enumerate(pixel_values):
            item_len = item.shape[-1]
            all_pixel_values[row, ..., :item_len] = item

        num_patches = tgt_sizes.prod(-1)
        max_patches = num_patches.max().item()
        assert isinstance(max_patches, int)

        # True for the leading positions that hold real patches.
        patch_attn_mask = torch.zeros(
            (batch_size, max_patches),
            dtype=torch.bool,
            device=device,
        )
        for row, patch_count in enumerate(num_patches):
            patch_attn_mask[row, :patch_count] = True

        vision_embedding = self.vpm(
            all_pixel_values,
            patch_attention_mask=patch_attn_mask.unsqueeze(1),
            tgt_sizes=tgt_sizes,
        )

        return self.resampler(vision_embedding, tgt_sizes)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load ``weights`` into this module through ``AutoWeightsLoader``.

        The prefixes ``"apm."``, ``"audio"`` and ``"tts"`` are handed to
        the loader as ``skip_prefixes``.
        """
        skipped = ["apm.", "audio", "tts"]
        weight_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return weight_loader.load_weights(weights)
# Maps a (major, minor) model version tuple to the class that implements
# it; the dispatching code looks the version up with ``.get(version)``.
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,
(2, 5): MiniCPMV2_5,
(2, 6): MiniCPMV2_6,
(4, 0): MiniCPMV4_0,
}
@@ -1294,8 +1420,10 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
# Dispatch class based on version
instance_cls = _SUPPORT_VERSION.get(version)
if instance_cls is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
supported_versions = ", ".join(
[f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())])
raise ValueError(f"Currently, MiniCPMV only supports versions "
f"{supported_versions}. Got version: {version}")
# quant_config references base class members,
# so update values before init is called