[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)

Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Cyrus Leung
2024-10-23 19:27:37 +08:00
committed by GitHub
parent 3ff57ebfca
commit c18e1a3418
18 changed files with 551 additions and 253 deletions

View File

@@ -395,7 +395,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
self.vpm = self.init_vision_module(config, quant_config)
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -647,7 +647,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> nn.Module:
raise NotImplementedError
def init_vision_module(self) -> nn.Module:
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
raise NotImplementedError
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
@@ -693,7 +697,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
# TODO :refactor this vision model
try:
import timm
@@ -817,8 +825,13 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config)
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
@@ -929,9 +942,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config)
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model