[Bugfix] Fix prefix strings for quantized VLMs (#9772)

This commit is contained in:
Michael Goin
2024-10-29 19:02:59 -04:00
committed by GitHub
parent 8d7724104a
commit bc73e9821c
20 changed files with 288 additions and 97 deletions

View File

@@ -394,8 +394,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module(config, quant_config)
self.llm = self.init_llm(config,
cache_config,
quant_config,
prefix="llm")
self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -403,9 +406,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix="llm.lm_head")
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@@ -644,6 +649,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
raise NotImplementedError
@@ -651,6 +657,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
raise NotImplementedError
@@ -690,17 +697,20 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
# TODO :refactor this vision model
try:
@@ -819,19 +829,23 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
@@ -935,20 +949,24 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model