[Misc] Add --seed option to offline multi-modal examples (#14934)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,11 +7,12 @@ For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
|
||||
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm import LLM
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@@ -37,12 +38,12 @@ Query = Union[TextQuery, ImageQuery, TextImageQuery]
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
llm: LLM
|
||||
engine_args: EngineArgs
|
||||
prompt: str
|
||||
image: Optional[Image]
|
||||
|
||||
|
||||
def run_e5_v(query: Query):
|
||||
def run_e5_v(query: Query) -> ModelRequestData:
|
||||
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
|
||||
|
||||
if query["modality"] == "text":
|
||||
@@ -58,20 +59,20 @@ def run_e5_v(query: Query):
|
||||
modality = query['modality']
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="royokong/e5-v",
|
||||
task="embed",
|
||||
max_model_len=4096,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
def run_vlm2vec(query: Query):
|
||||
def run_vlm2vec(query: Query) -> ModelRequestData:
|
||||
if query["modality"] == "text":
|
||||
text = query["text"]
|
||||
prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501
|
||||
@@ -87,7 +88,7 @@ def run_vlm2vec(query: Query):
|
||||
modality = query['modality']
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="TIGER-Lab/VLM2Vec-Full",
|
||||
task="embed",
|
||||
trust_remote_code=True,
|
||||
@@ -95,7 +96,7 @@ def run_vlm2vec(query: Query):
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
)
|
||||
@@ -126,15 +127,18 @@ def get_query(modality: QueryModality):
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def run_encode(model: str, modality: QueryModality):
|
||||
def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
query = get_query(modality)
|
||||
req_data = model_example_map[model](query)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
mm_data = {}
|
||||
if req_data.image is not None:
|
||||
mm_data["image"] = req_data.image
|
||||
|
||||
outputs = req_data.llm.embed({
|
||||
outputs = llm.embed({
|
||||
"prompt": req_data.prompt,
|
||||
"multi_modal_data": mm_data,
|
||||
})
|
||||
@@ -144,7 +148,7 @@ def run_encode(model: str, modality: QueryModality):
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
run_encode(args.model_name, args.modality)
|
||||
run_encode(args.model_name, args.modality, args.seed)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
@@ -167,5 +171,10 @@ if __name__ == "__main__":
|
||||
default="image",
|
||||
choices=get_args(QueryModality),
|
||||
help='Modality of the input.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user