[Model] Add multi-image support for minicpmv (#7122)

Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Alphi
Date: 2024-08-05 09:23:17 +08:00
Committed by: GitHub
Commit: 7b86e7c9cd (parent: f80ab3521c)
4 changed files with 172 additions and 37 deletions
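
The change below lets a single MiniCPM-V prompt reference several images: the input processor now accepts a list of PIL images and expands one "(<image>./</image>)" tag per image. A hedged usage sketch follows; the model name, file names, and sampling settings are illustrative and not part of this diff, and the exact engine flags for multi-image prompts may differ across vLLM versions:

    # Sketch only: illustrative model and inputs, not taken from this commit.
    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True)
    # One "(<image>./</image>)" tag per image in the prompt text.
    prompt = "(<image>./</image>)\n(<image>./</image>)\nCompare the two images."
    images = [Image.open("a.jpg"), Image.open("b.jpg")]
    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": images}},
        SamplingParams(max_tokens=64),
    )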

@@ -392,6 +392,20 @@ class Resampler2_5(BaseResampler):
         return x
 
 
+def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+    version_float = getattr(config, "version", None)
+
+    # The old configs do not include version number
+    # TODO: Remove this after the HF repos are updated
+    if version_float is None:
+        if config.hidden_size == 2304 and config.query_num == 64:
+            return (2, 0)
+        return (2, 5)
+
+    version_str = str(version_float)
+    return tuple(int(x) for x in version_str.split("."))
+
+
 def get_max_minicpmv_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     return getattr(hf_config, "query_num", 64)
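
Two properties of get_version_by_config are worth noting: for legacy configs without a "version" field it falls back to a sizing heuristic, and it returns a tuple of ints rather than the raw string, so version comparisons order correctly. A small sketch using transformers' PretrainedConfig; the field values here are made up for illustration:

    from transformers import PretrainedConfig

    # Legacy config: no "version" attribute, identified by its dimensions.
    legacy = PretrainedConfig(hidden_size=2304, query_num=64)
    assert get_version_by_config(legacy) == (2, 0)

    # Config that carries an explicit version number.
    tagged = PretrainedConfig(version=2.5, hidden_size=4096, query_num=96)
    assert get_version_by_config(tagged) == (2, 5)

    # Int tuples compare numerically; raw version strings would not.
    assert (2, 10) > (2, 5)
    assert "2.10" < "2.5"  # lexicographic comparison gets this wrong
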
@@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
     model_config = ctx.model_config
+    version = get_version_by_config(model_config.hf_config)
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
+    image_processor = cached_get_image_processor(model_config.tokenizer)
+
+    def get_placeholder(image_size: Tuple[int, int], num_image: int):
+        if version == (2, 0) or version == (2, 5):
+            return image_processor. \
+                get_slice_image_placeholder(image_size)
+        return image_processor. \
+            get_slice_image_placeholder(image_size, num_image)
 
     prompt = llm_inputs.get("prompt")
     if prompt is None:
         token_ids = llm_inputs.get("prompt_token_ids")
         prompt = tokenizer.decode(token_ids)
-    image_processor = cached_get_image_processor(model_config.tokenizer)
 
     pattern = "(<image>./</image>)"
-    image = multi_modal_data["image"]
+    images = multi_modal_data["image"]
+    if isinstance(images, Image.Image):
+        images = [images]
     image_tags = re.findall(pattern, prompt)
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
-        if len(image_tags) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
         text_chunks = prompt.split(pattern)
-        new_prompt = (text_chunks[0] +
-                      image_processor.get_slice_image_placeholder(image.size) +
-                      "".join(text_chunks[1:]))
+        new_prompt_chunks: List[str] = []
+        for i in range(len(images)):
+            new_prompt_chunks += [
+                text_chunks[i],
+                get_placeholder(images[i].size, i)
+            ]
+        new_prompt_chunks.append(text_chunks[-1])
+        new_prompt = "".join(new_prompt_chunks)
 
     new_token_ids = tokenizer.encode(new_prompt)
 
     llm_inputs = LLMInputs(
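
The rewritten else branch interleaves the text between image tags with one placeholder per image, instead of collapsing everything after the first tag into a single placeholder. A standalone sketch of that interleaving, using a dummy placeholder string in place of the HF image processor's get_slice_image_placeholder:

    pattern = "(<image>./</image>)"
    prompt = "(<image>./</image>) and (<image>./</image>) differ how?"
    sizes = [(448, 448), (640, 480)]  # stand-ins for images[i].size

    text_chunks = prompt.split(pattern)
    new_prompt_chunks = []
    for i, size in enumerate(sizes):
        placeholder = f"<placeholder:{i}:{size}>"  # dummy stand-in
        new_prompt_chunks += [text_chunks[i], placeholder]
    new_prompt_chunks.append(text_chunks[-1])
    new_prompt = "".join(new_prompt_chunks)
    # -> "<placeholder:0:(448, 448)> and <placeholder:1:(640, 480)> differ how?"
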
@@ -478,14 +499,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if not hasattr(self.config, "version"):
-            if self.config.hidden_size == 2304 and self.config.query_num == 64:
-                self.version = (2, 0)
-            else:
-                self.version = (2, 5)
-        else:
-            self.version = str(self.config.version).split(".")
-            self.version = tuple([int(x) for x in self.version])
+        self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config, cache_config, quant_config)
         self.vpm = self.init_vision_module()
         param_dtype = torch.get_default_dtype()
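
With the constructor delegating to get_version_by_config, every consumer sees the same normalized (major, minor) tuple. One common use for such a tuple is as a dispatch key; the sketch below is hypothetical, and the class names and mapping are illustrative rather than taken from this file:

    # Hypothetical dispatch table keyed by the normalized version tuple.
    _VERSION_TO_IMPL = {
        (2, 0): MiniCPMV2_0,  # illustrative class names
        (2, 5): MiniCPMV2_5,
    }

    def build_model(config):
        version = get_version_by_config(config)
        try:
            return _VERSION_TO_IMPL[version](config)
        except KeyError:
            raise ValueError(f"Unsupported MiniCPM-V version: {version}")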