[Misc] Add --seed option to offline multi-modal examples (#14934)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-17 18:00:17 +08:00
committed by GitHub
parent 868a8c5b2c
commit 6eaf1e5c52
6 changed files with 537 additions and 315 deletions

View File

@@ -7,11 +7,13 @@ For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
from dataclasses import asdict
from typing import NamedTuple, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
@@ -23,21 +25,31 @@ question_per_audio_count = {
2: "What sport and what nursery rhyme are referenced?"
}
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str
stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int):
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
llm = LLM(model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
@@ -52,11 +64,16 @@ def run_minicpmo(question: str, audio_count: int):
tokenize=False,
add_generation_prompt=True,
chat_template=audio_chat_template)
return llm, prompt, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
)
# Phi-4-multimodal-instruct
def run_phi4mm(questions: str, audio_count: int):
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
@@ -67,9 +84,9 @@ def run_phi4mm(questions: str, audio_count: int):
speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
llm = LLM(
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
@@ -79,24 +96,24 @@ def run_phi4mm(questions: str, audio_count: int):
lora_extra_vocab_size=0,
limit_mm_per_prompt={"audio": audio_count},
)
lora_request = LoRARequest("speech", 1, speech_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join([
f"Audio {idx+1}: "
@@ -107,12 +124,15 @@ def run_qwen2_audio(question: str, audio_count: int):
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int):
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -124,29 +144,39 @@ def run_ultravox(question: str, audio_count: int):
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Whisper
def run_whisper(question: str, audio_count: int):
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, (
"Whisper only support single audio input per prompt")
model_name = "openai/whisper-large-v3-turbo"
prompt = "<|startoftranscript|>"
llm = LLM(model=model_name,
max_model_len=448,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
engine_args = EngineArgs(
model=model_name,
max_model_len=448,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
model_example_map = {
@@ -164,14 +194,24 @@ def main(args):
raise ValueError(f"Model type {model} is not supported.")
audio_count = args.num_audios
llm, prompt, stop_token_ids = model_example_map[model](
question_per_audio_count[audio_count], audio_count)
req_data = model_example_map[model](question_per_audio_count[audio_count],
audio_count)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
# To maintain code compatibility in this script, we add LoRA here.
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
if req_data.lora_requests:
for lora_request in req_data.lora_requests:
llm.llm_engine.add_lora(lora_request=lora_request)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=stop_token_ids)
stop_token_ids=req_data.stop_token_ids)
mm_data = {}
if audio_count > 0:
@@ -183,7 +223,7 @@ def main(args):
}
assert args.num_prompts > 0
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
if args.num_prompts > 1:
# Batch inference
inputs = [inputs] * args.num_prompts
@@ -214,6 +254,10 @@ if __name__ == "__main__":
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args)