[Model] Rename MiniCPMVQwen2 to MiniCPMV2.6 (#7273)
This commit is contained in:
@@ -216,7 +216,6 @@ class BaseResampler(nn.Module):
|
||||
|
||||
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
||||
trunc_normal_(self.query, std=0.02)
|
||||
|
||||
if kv_dim is not None and kv_dim != embed_dim:
|
||||
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
|
||||
else:
|
||||
@@ -225,7 +224,6 @@ class BaseResampler(nn.Module):
|
||||
nn.Identity()(*args, **kwargs),
|
||||
None,
|
||||
)
|
||||
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
self.ln_q = norm_layer(embed_dim)
|
||||
self.ln_kv = norm_layer(embed_dim)
|
||||
@@ -261,7 +259,6 @@ class Resampler2(BaseResampler):
|
||||
norm_layer)
|
||||
|
||||
self.adaptive = adaptive
|
||||
|
||||
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
|
||||
grid_size,
|
||||
version=(2, 0))
|
||||
@@ -717,7 +714,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MiniCPMV2(MiniCPMVBaseModel):
|
||||
class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -890,10 +887,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
|
||||
return "resampler" in name
|
||||
|
||||
|
||||
# NOTE: Currently, information about this model is unavailable. We are
|
||||
# temporarily using `MiniCPMVQwen2` as it's name. The name may need
|
||||
# to be modified in the future.
|
||||
class MiniCPMVQwen2(MiniCPMVBaseModel):
|
||||
class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -903,6 +897,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__(config, multimodal_config, cache_config, quant_config)
|
||||
assert self.version == (2, 6)
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
@@ -930,6 +925,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
|
||||
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
# The resampler in 2.6 remains consistent with the one in 2.5.
|
||||
resampler = Resampler2_5(
|
||||
num_queries=self.config.query_num,
|
||||
embed_dim=embed_dim,
|
||||
@@ -989,6 +985,13 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
|
||||
return "resampler" in name or "vpm" in name
|
||||
|
||||
|
||||
_SUPPORT_VERSION = {
|
||||
(2, 0): MiniCPMV2_0,
|
||||
(2, 5): MiniCPMV2_5,
|
||||
(2, 6): MiniCPMV2_6
|
||||
}
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
|
||||
@@ -1016,11 +1019,9 @@ class MiniCPMV(MiniCPMVBaseModel):
|
||||
version = str(config.version).split(".")
|
||||
version = tuple([int(x) for x in version])
|
||||
# Dispatch class based on version
|
||||
if version == (2, 0):
|
||||
instance_class = MiniCPMV2
|
||||
elif version == (2, 5):
|
||||
instance_class = MiniCPMV2_5
|
||||
else:
|
||||
instance_class = MiniCPMVQwen2
|
||||
instance_class = _SUPPORT_VERSION.get(version, None)
|
||||
if instance_class is None:
|
||||
raise ValueError(
|
||||
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
|
||||
return instance_class(config, multimodal_config, cache_config,
|
||||
quant_config)
|
||||
|
||||
Reference in New Issue
Block a user