[Model] Add multi-image support for minicpmv (#7122)
Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -392,6 +392,20 @@ class Resampler2_5(BaseResampler):
         return x
 
 
+def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+    version_float = getattr(config, "version", None)
+
+    # The old configs do not include version number
+    # TODO: Remove this after the HF repos are updated
+    if version_float is None:
+        if config.hidden_size == 2304 and config.query_num == 64:
+            return (2, 0)
+        return (2, 5)
+
+    version_str = str(version_float)
+    return tuple(int(x) for x in version_str.split("."))
+
+
 def get_max_minicpmv_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     return getattr(hf_config, "query_num", 64)
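The helper added above normalizes both config styles into a comparable version tuple. A minimal sketch of that behavior, using `types.SimpleNamespace` as a stand-in for `PretrainedConfig` (illustration only, not part of the diff):

```python
from types import SimpleNamespace

def parse_version(config) -> tuple:
    """Mirrors the logic of the added get_version_by_config helper."""
    version_float = getattr(config, "version", None)
    if version_float is None:
        # Old HF configs carry no version field; fall back to model-shape hints.
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)
    return tuple(int(x) for x in str(version_float).split("."))

# A config with an explicit version field, e.g. version=2.5 -> (2, 5)
assert parse_version(SimpleNamespace(version=2.5)) == (2, 5)
# An old MiniCPM-V-2 config without the field -> (2, 0)
assert parse_version(SimpleNamespace(hidden_size=2304, query_num=64)) == (2, 0)
```

Returning a tuple rather than a float or string lets callers compare versions with plain tuple equality, which the new `get_placeholder` closure below relies on.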
@@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
 
     model_config = ctx.model_config
 
+    version = get_version_by_config(model_config.hf_config)
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
+    image_processor = cached_get_image_processor(model_config.tokenizer)
+
+    def get_placeholder(image_size: Tuple[int, int], num_image: int):
+        if version == (2, 0) or version == (2, 5):
+            return image_processor. \
+                get_slice_image_placeholder(image_size)
+        return image_processor. \
+            get_slice_image_placeholder(image_size, num_image)
 
     prompt = llm_inputs.get("prompt")
     if prompt is None:
         token_ids = llm_inputs.get("prompt_token_ids")
         prompt = tokenizer.decode(token_ids)
-    image_processor = cached_get_image_processor(model_config.tokenizer)
 
     pattern = "(<image>./</image>)"
-    image = multi_modal_data["image"]
+    images = multi_modal_data["image"]
+    if isinstance(images, Image.Image):
+        images = [images]
     image_tags = re.findall(pattern, prompt)
 
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
-        if len(image_tags) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
-
         text_chunks = prompt.split(pattern)
-        new_prompt = (text_chunks[0] +
-                      image_processor.get_slice_image_placeholder(image.size) +
-                      "".join(text_chunks[1:]))
-
+        new_prompt_chunks: List[str] = []
+        for i in range(len(images)):
+            new_prompt_chunks += [
+                text_chunks[i],
+                get_placeholder(images[i].size, i)
+            ]
+        new_prompt_chunks.append(text_chunks[-1])
+        new_prompt = "".join(new_prompt_chunks)
         new_token_ids = tokenizer.encode(new_prompt)
 
     llm_inputs = LLMInputs(
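The rewritten processor splits the prompt on the literal image tag and interleaves one placeholder per image, instead of substituting a single placeholder and warning about extra tags. A self-contained sketch of that interleaving, where `make_placeholder` is a hypothetical stand-in for the HF image processor's `get_slice_image_placeholder`:

```python
from typing import List, Tuple

PATTERN = "(<image>./</image>)"

def make_placeholder(image_size: Tuple[int, int], idx: int) -> str:
    # Hypothetical stand-in; the real placeholder depends on image slicing.
    return f"<image_placeholder_{idx}>"

def expand_prompt(prompt: str, image_sizes: List[Tuple[int, int]]) -> str:
    # str.split matches the tag literally, so each chunk is the text
    # between (or around) consecutive image tags.
    text_chunks = prompt.split(PATTERN)
    new_chunks: List[str] = []
    for i, size in enumerate(image_sizes):
        new_chunks += [text_chunks[i], make_placeholder(size, i)]
    new_chunks.append(text_chunks[-1])
    return "".join(new_chunks)

prompt = f"Compare {PATTERN} with {PATTERN}."
print(expand_prompt(prompt, [(448, 448), (896, 448)]))
# Compare <image_placeholder_0> with <image_placeholder_1>.
```

Like the diff, this assumes the number of image tags matches the number of images supplied; the `get_placeholder` closure in the real code additionally picks the placeholder signature by model version.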
@@ -478,14 +499,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if not hasattr(self.config, "version"):
-            if self.config.hidden_size == 2304 and self.config.query_num == 64:
-                self.version = (2, 0)
-            else:
-                self.version = (2, 5)
-        else:
-            self.version = str(self.config.version).split(".")
-            self.version = tuple([int(x) for x in self.version])
+        self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config, cache_config, quant_config)
         self.vpm = self.init_vision_module()
         param_dtype = torch.get_default_dtype()
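End to end, the change means `multi_modal_data["image"]` may now be a list of PIL images rather than a single image. A hedged usage sketch with vLLM's offline API; it assumes a build where multi-image inputs are enabled end to end, and the model name, file paths, and sampling settings are illustrative:

```python
from PIL import Image
from vllm import LLM, SamplingParams

# Model name is an example; any MiniCPM-V checkpoint supported by vLLM works.
llm = LLM(model="openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True)

images = [
    Image.open("first.png").convert("RGB"),
    Image.open("second.png").convert("RGB"),
]

outputs = llm.generate(
    {
        # One "(<image>./</image>)" tag per image, in prompt order.
        "prompt": "(<image>./</image>)(<image>./</image>)\n"
                  "What differs between the two images?",
        # After this change, "image" accepts a list as well as a single image.
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```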