[Meta] Official Eagle mm support, first enablement on llama4 (#20788)
Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@@ -13,6 +13,38 @@ except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
|
||||
# Text instruction attached to every image in the custom multimodal prompts.
QUESTION = "What is the content of each image?"

# Wikimedia Commons image URLs used by get_custom_mm_prompts(), which walks
# this list in order and wraps around when more prompts than URLs are needed.
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]
|
||||
|
||||
|
||||
def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations, one image each.

    Each conversation is a one-element list holding a ``user`` message whose
    content pairs one image URL from ``IMAGE_URLS`` with the ``QUESTION``
    text. URLs are taken in list order and wrap around once ``num_prompts``
    exceeds ``len(IMAGE_URLS)``.

    Args:
        num_prompts: Number of conversations to build. Zero or a negative
            value yields an empty list (a negative value previously produced
            a silently truncated list through negative slicing).

    Returns:
        A list of ``num_prompts`` conversations; each conversation is a list
        containing a single ``{"role": "user", "content": [...]}`` message.
    """
    num_urls = len(IMAGE_URLS)
    # Every message is built from scratch, so conversations that reuse the
    # same URL never share (alias) the same content list or dict objects.
    return [
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": IMAGE_URLS[index % num_urls]},
                    },
                    {"type": "text", "text": QUESTION},
                ],
            }
        ]
        for index in range(num_prompts)
    ]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
add_dataset_parser(parser)
|
||||
@@ -35,6 +67,7 @@ def parse_args():
|
||||
parser.add_argument("--output-len", type=int, default=256)
|
||||
parser.add_argument("--model-dir", type=str, default=None)
|
||||
parser.add_argument("--eagle-dir", type=str, default=None)
|
||||
parser.add_argument("--custom-mm-prompts", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -44,14 +77,26 @@ def main():
|
||||
|
||||
model_dir = args.model_dir
|
||||
if args.model_dir is None:
|
||||
if args.custom_mm_prompts:
|
||||
raise ValueError(
|
||||
"custom_mm_prompts requires mm based models"
|
||||
"default llama3.1-8b-instruct is not mm based"
|
||||
"please specify model_dir to give a mm based model"
|
||||
)
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
args.custom_skip_chat_template = True
|
||||
|
||||
prompts = get_samples(args, tokenizer)
|
||||
# add_special_tokens is False to avoid adding bos twice when using chat templates
|
||||
prompt_ids = [
|
||||
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
|
||||
]
|
||||
if not args.custom_mm_prompts:
|
||||
prompts = get_samples(args, tokenizer)
|
||||
# add_special_tokens is False to avoid adding bos twice
|
||||
# when using chat templates
|
||||
prompt_ids = [
|
||||
tokenizer.encode(prompt.prompt, add_special_tokens=False)
|
||||
for prompt in prompts
|
||||
]
|
||||
else:
|
||||
prompts = get_custom_mm_prompts(args.num_prompts)
|
||||
|
||||
if args.method == "eagle" or args.method == "eagle3":
|
||||
eagle_dir = args.eagle_dir
|
||||
@@ -85,10 +130,17 @@ def main():
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=16384,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
|
||||
if not args.custom_mm_prompts:
|
||||
outputs = llm.generate(
|
||||
prompt_token_ids=prompt_ids, sampling_params=sampling_params
|
||||
)
|
||||
else:
|
||||
outputs = llm.chat(prompts, sampling_params=sampling_params)
|
||||
|
||||
# print the generated text
|
||||
if args.print_output:
|
||||
|
||||
Reference in New Issue
Block a user