[Model] Rename MiniCPMVQwen2 to MiniCPMV2.6 (#7273)

This commit is contained in:
Jee Jee Li
2024-08-08 22:02:41 +08:00
committed by GitHub
parent 6dffa4b0a6
commit 757ac70a64
3 changed files with 51 additions and 31 deletions

View File

@@ -216,7 +216,6 @@ class BaseResampler(nn.Module):
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
@@ -225,7 +224,6 @@ class BaseResampler(nn.Module):
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
@@ -261,7 +259,6 @@ class Resampler2(BaseResampler):
norm_layer)
self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
@@ -717,7 +714,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
raise NotImplementedError
class MiniCPMV2(MiniCPMVBaseModel):
class MiniCPMV2_0(MiniCPMVBaseModel):
def __init__(
self,
@@ -890,10 +887,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
return "resampler" in name
# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as it's name. The name may need
# to be modified in the future.
class MiniCPMVQwen2(MiniCPMVBaseModel):
class MiniCPMV2_6(MiniCPMVBaseModel):
def __init__(
self,
@@ -903,6 +897,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 6)
def init_llm(
self,
@@ -930,6 +925,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16):
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
@@ -989,6 +985,13 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
return "resampler" in name or "vpm" in name
# Maps a (major, minor) version tuple to the class implementing that
# MiniCPM-V version. The dispatch code looks up the integer tuple parsed
# from `config.version` in this table; a version with no entry here is
# rejected with a ValueError.
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,
(2, 5): MiniCPMV2_5,
(2, 6): MiniCPMV2_6
}
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@@ -1016,11 +1019,9 @@ class MiniCPMV(MiniCPMVBaseModel):
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
if version == (2, 0):
instance_class = MiniCPMV2
elif version == (2, 5):
instance_class = MiniCPMV2_5
else:
instance_class = MiniCPMVQwen2
instance_class = _SUPPORT_VERSION.get(version, None)
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(config, multimodal_config, cache_config,
quant_config)