[V1] VLM preprocessor hashing (#11020)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: Alexander Matveev
Date: 2024-12-11 19:55:30 -05:00 (committed by GitHub)
Parent: 452a723bf2
Commit: 4e11683368
11 changed files with 332 additions and 48 deletions
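For context, the caching idea behind this change can be sketched as follows: the output of the (expensive) multi-modal preprocessor is memoized under a content hash of the raw input, so a repeated image skips preprocessing entirely. This is only a minimal illustration of the concept; the names below (MMPreprocessorCache, _hash_image) are hypothetical and do not appear in this diff.

import hashlib
from typing import Any, Callable, Dict


class MMPreprocessorCache:
    """Hypothetical sketch: memoize preprocessor outputs keyed by input hash."""

    def __init__(self, preprocess: Callable[[Any], Any]):
        self._preprocess = preprocess
        self._cache: Dict[str, Any] = {}

    @staticmethod
    def _hash_image(image) -> str:
        # Identical pixel data -> identical key (assumes a PIL image).
        return hashlib.sha256(image.tobytes()).hexdigest()

    def __call__(self, image):
        key = self._hash_image(image)
        if key not in self._cache:  # cache miss: run the preprocessor once
            self._cache[key] = self._preprocess(image)
        return self._cache[key]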

@@ -5,6 +5,8 @@ the correct prompt format on vision language models for text generation.
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import random
+
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -23,7 +25,9 @@ def run_llava(question: str, modality: str):
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
 
-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"[INST] <image>\n{question} [/INST]"
-    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str):
     assert modality == "video"
 
     prompt = f"USER: <video>\n{question} ASSISTANT:"
-    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -61,7 +69,8 @@ def run_llava_onevision(question: str, modality: str):
     <|im_start|>assistant\n"
 
     llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
-              max_model_len=16384)
+              max_model_len=16384,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -71,7 +80,10 @@ def run_fuyu(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"{question}\n"
-    llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
+    llm = LLM(model="adept/fuyu-8b",
+              max_model_len=2048,
+              max_num_seqs=2,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -107,6 +119,7 @@ def run_phi3v(question: str, modality: str):
         max_num_seqs=2,
         # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -118,7 +131,8 @@ def run_paligemma(question: str, modality: str):
 
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
-    llm = LLM(model="google/paligemma-3b-mix-224")
+    llm = LLM(model="google/paligemma-3b-mix-224",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -128,7 +142,9 @@ def run_chameleon(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"{question}<image>"
-    llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
+    llm = LLM(model="facebook/chameleon-7b",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -154,6 +170,7 @@ def run_minicpmv(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=2,
         trust_remote_code=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     # NOTE The stop_token_ids are different for various versions of MiniCPM-V
     # 2.0
@@ -186,6 +203,7 @@ def run_h2ovl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -211,6 +229,7 @@ def run_internvl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -241,6 +260,7 @@ def run_nvlm_d(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         tensor_parallel_size=4,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -260,7 +280,8 @@ def run_blip2(question: str, modality: str):
     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
-    llm = LLM(model="Salesforce/blip2-opt-2.7b")
+    llm = LLM(model="Salesforce/blip2-opt-2.7b",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -274,6 +295,7 @@ def run_qwen_vl(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=1024,
         max_num_seqs=2,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = f"{question}Picture 1: <img></img>\n"
@@ -296,6 +318,7 @@ def run_qwen2_vl(question: str, modality: str):
             "min_pixels": 28 * 28,
             "max_pixels": 1280 * 28 * 28,
         },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -315,6 +338,7 @@ def run_pixtral_hf(question: str, modality: str):
     llm = LLM(
         model=model_name,
         max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -338,6 +362,7 @@ def run_mllama(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=16,
         enforce_eager=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = f"<|image|><|begin_of_text|>{question}"
@@ -355,6 +380,7 @@ def run_molmo(question, modality):
         model=model_name,
         trust_remote_code=True,
         dtype="bfloat16",
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = question
@@ -371,7 +397,8 @@ def run_glm4v(question: str, modality: str):
               max_model_len=2048,
               max_num_seqs=2,
               trust_remote_code=True,
-              enforce_eager=True)
+              enforce_eager=True,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     prompt = question
     stop_token_ids = [151329, 151336, 151338]
     return llm, prompt, stop_token_ids
@@ -394,6 +421,7 @@ def run_idefics3(question: str, modality: str):
                 "longest_edge": 3 * 364
             },
         },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = (
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -410,7 +438,8 @@ def run_aria(question: str, modality: str):
     llm = LLM(model=model_name,
               tokenizer_mode="slow",
               trust_remote_code=True,
-              dtype="bfloat16")
+              dtype="bfloat16",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
 
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
@@ -430,6 +459,7 @@ def run_mantis(question: str, modality: str):
         model="TIGER-Lab/Mantis-8B-siglip-llama3",
         max_model_len=4096,
         hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     stop_token_ids = [128009]
     return llm, prompt, stop_token_ids
@@ -494,6 +524,35 @@ def get_multi_modal_input(args):
         raise ValueError(msg)
 
 
+def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
+    """Repeats images with provided probability of "image_repeat_prob".
+    Used to simulate hit/miss for the MM preprocessor cache.
+    """
+    assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
+    no_yes = [0, 1]
+    probs = [1.0 - image_repeat_prob, image_repeat_prob]
+
+    inputs = []
+    cur_image = data
+    for i in range(num_prompts):
+        if image_repeat_prob is not None:
+            res = random.choices(no_yes, probs)[0]
+            if res == 0:
+                # No repeat => Modify one pixel
+                cur_image = cur_image.copy()
+                new_val = (i // 256 // 256, i // 256, i % 256)
+                cur_image.putpixel((0, 0), new_val)
+
+        inputs.append({
+            "prompt": prompt,
+            "multi_modal_data": {
+                modality: cur_image
+            }
+        })
+
+    return inputs
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
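As a quick sanity check of apply_image_repeat (a sketch, not part of the diff; assumes Pillow and the "image" modality): with image_repeat_prob=0.5, roughly half the prompts reuse the previous image, which would be a preprocessor-cache hit, while the rest get a one-pixel edit that forces a new content hash, a cache miss.

from PIL import Image

img = Image.new("RGB", (64, 64))
batch = apply_image_repeat(image_repeat_prob=0.5,
                           num_prompts=100,
                           data=img,
                           prompt="USER: <image>\nDescribe.\nASSISTANT:",
                           modality="image")
# Count distinct images; expect roughly 50 unique out of 100.
unique = len({x["multi_modal_data"]["image"].tobytes() for x in batch})
print(f"{unique} unique images out of {len(batch)}")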
@@ -524,14 +583,29 @@
     else:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
+        if args.image_repeat_prob is not None:
+            # Repeat images with specified probability of "image_repeat_prob"
+            inputs = apply_image_repeat(args.image_repeat_prob,
+                                        args.num_prompts, data, prompt,
+                                        modality)
+        else:
+            # Use the same image for all prompts
+            inputs = [{
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: data
+                },
+            } for _ in range(args.num_prompts)]
 
-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    if args.time_generate:
+        import time
+        start_time = time.time()
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
+        elapsed_time = time.time() - start_time
+        print("-- generate time = {}".format(elapsed_time))
+    else:
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
 
     for o in outputs:
         generated_text = o.outputs[0].text
@@ -561,5 +635,23 @@ if __name__ == "__main__":
                         type=int,
                         default=16,
                         help='Number of frames to extract from the video.')
+
+    parser.add_argument(
+        '--image-repeat-prob',
+        type=float,
+        default=None,
+        help='Simulates the hit-ratio for multi-modal preprocessor cache'
+        ' (if enabled)')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='If True, enable caching of multi-modal preprocessor/mapper.')
+
+    parser.add_argument(
+        '--time-generate',
+        action='store_true',
+        help='If True, then print the total generate() call time')
+
     args = parser.parse_args()
     main(args)
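Putting the three new flags together, a typical invocation of this example script might look like the following. This is a hypothetical command line: the script path and the pre-existing --model-type/--num-prompts flags are assumed from context rather than shown in this diff.

python offline_inference_vision_language.py \
    --model-type llava \
    --num-prompts 100 \
    --mm-cache-preprocessor \
    --image-repeat-prob 0.5 \
    --time-generate

Comparing the printed generate() time with and without --mm-cache-preprocessor, at a fixed --image-repeat-prob, is the intended way to observe the effect of the preprocessor cache.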