[Model] Clean up MiniCPMV (#10751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-29 12:47:06 +08:00
committed by GitHub
parent c83919c7a6
commit fa6ecb9aa7
7 changed files with 149 additions and 215 deletions

View File

@@ -295,16 +295,29 @@ VLM_TEST_SETTINGS = {
)
],
),
"minicpmv": VLMTestInfo(
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id],
postprocess_inputs=model_utils.wrap_inputs_post_processor,
hf_output_post_proc=model_utils.minicmpv_trunc_hf_output,
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
"minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
postprocess_inputs=model_utils.ignore_inputs_post_processor(
"image_sizes"
),
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.

View File

@@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
####### Post-processors for HF outputs
def minicmpv_trunc_hf_output(hf_output: RunnerOutput,
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|eot_id|>"):
@@ -197,6 +197,17 @@ def get_key_type_post_processor(
return process
def ignore_inputs_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which ignores a given key."""
def process(hf_inputs: BatchEncoding, dtype: str):
del hf_inputs[hf_inp_key]
return hf_inputs
return process
def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
return {"model_inputs": hf_inputs}