[Model] Improve multimodal pooling examples (#32085)
Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
93
examples/pooling/embed/vision_embedding_offline.py
Normal file
93
examples/pooling/embed/vision_embedding_offline.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference with
|
||||
the correct prompt format on vision language models for multimodal embedding.
|
||||
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from dataclasses import asdict
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg"
|
||||
text = "A cat standing in the snow."
|
||||
multi_modal_data = {"image": fetch_image(image_url)}
|
||||
|
||||
|
||||
def print_embeddings(embeds):
    """Print a short preview of an embedding vector together with its length.

    Vectors with more than four entries are shown as their first four values
    followed by ``...``; shorter vectors are printed in full.
    """
    size = len(embeds)
    if size > 4:
        # Drop the closing bracket of the 4-element preview and add an ellipsis.
        preview = str(embeds[:4])[:-1] + ", ...]"
    else:
        preview = embeds
    print(f"Embeddings: {preview} (size={size})")
|
||||
|
||||
|
||||
def run_qwen3_vl():
    """Embed a text, an image, and an image+text prompt with Qwen3-VL-Embedding-2B.

    The chat-formatted prompts are written out by hand, each one is passed to
    ``LLM.embed`` and a trimmed preview of the resulting embedding is printed.
    """
    engine_args = EngineArgs(
        model="Qwen/Qwen3-VL-Embedding-2B",
        runner="pooling",
        max_model_len=8192,
        limit_mm_per_prompt={"image": 1},
    )

    default_instruction = "Represent the user's input."
    image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"

    def build_prompt(user_content):
        # System instruction, one user turn, then an opened assistant turn.
        return (
            f"<|im_start|>system\n{default_instruction}<|im_end|>\n"
            f"<|im_start|>user\n{user_content}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    text_prompt = build_prompt(text)
    image_prompt = build_prompt(image_placeholder)
    image_text_prompt = build_prompt(f"{image_placeholder}{text}")

    llm = LLM(**asdict(engine_args))

    # (heading, request) pairs, run in the order they are listed.
    requests_to_run = (
        ("Text embedding output:", text_prompt),
        (
            "Image embedding output:",
            {"prompt": image_prompt, "multi_modal_data": multi_modal_data},
        ),
        (
            "Image+Text embedding output:",
            {"prompt": image_text_prompt, "multi_modal_data": multi_modal_data},
        ),
    )
    for heading, request in requests_to_run:
        print(heading)
        outputs = llm.embed(request, use_tqdm=False)
        print_embeddings(outputs[0].outputs.embedding)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"qwen3_vl": run_qwen3_vl,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line arguments of this example script.

    Returns:
        The ``argparse.Namespace`` produced by the parser; its ``model``
        attribute is one of the keys of ``model_example_map``.
    """
    parser = argparse.ArgumentParser(
        # Passed by keyword on purpose: the first positional parameter of
        # ArgumentParser is ``prog``, so a positional string would replace the
        # script name in the usage line instead of describing the script.
        description="Script to run a specified VLM through vLLM offline api."
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=list(model_example_map),
        required=True,
        help="The name of the embedding model.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """Run the example function registered under ``args.model``."""
    run_example = model_example_map[args.model]
    run_example()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -21,7 +21,8 @@ from PIL import Image
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg"
|
||||
text = "A cat standing in the snow."
|
||||
|
||||
|
||||
def create_chat_embeddings(
|
||||
@@ -30,6 +31,8 @@ def create_chat_embeddings(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: str,
|
||||
encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN,
|
||||
continue_final_message: bool = False,
|
||||
add_special_tokens: bool = False,
|
||||
) -> CreateEmbeddingResponse:
|
||||
"""
|
||||
Convenience function for accessing vLLM's Chat Embeddings API,
|
||||
@@ -38,10 +41,21 @@ def create_chat_embeddings(
|
||||
return client.post(
|
||||
"/embeddings",
|
||||
cast_to=CreateEmbeddingResponse,
|
||||
body={"messages": messages, "model": model, "encoding_format": encoding_format},
|
||||
body={
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"encoding_format": encoding_format,
|
||||
"continue_final_message": continue_final_message,
|
||||
"add_special_tokens": add_special_tokens,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def print_embeddings(embeds):
    """Print an embedding vector, abbreviated to four values when longer, and its size."""
    total = len(embeds)
    # Only vectors longer than four entries are abbreviated.
    shown = embeds if total <= 4 else f"{str(embeds[:4])[:-1]}, ...]"
    print(f"Embeddings: {shown} (size={total})")
|
||||
|
||||
|
||||
def run_clip(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
@@ -145,6 +159,113 @@ def run_dse_qwen2_vl(client: OpenAI, model: str):
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_qwen3_vl(client: OpenAI, model: str):
    """
    Request text, image, and image+text embeddings and print a preview of each.

    Start the server using:

    vllm serve Qwen/Qwen3-VL-Embedding-2B \
        --runner pooling \
        --max-model-len 8192
    """

    default_instruction = "Represent the user's input."

    def text_part(value):
        return {"type": "text", "text": value}

    def image_part():
        return {"type": "image_url", "image_url": {"url": image_url}}

    def embed(user_content):
        # Every request shares the same system turn and ends with an empty
        # assistant turn; only the user content differs.
        return create_chat_embeddings(
            client,
            messages=[
                {"role": "system", "content": [text_part(default_instruction)]},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": [text_part("")]},
            ],
            model=model,
            encoding_format="float",
            continue_final_message=True,
            add_special_tokens=True,
        )

    scenarios = (
        ("Text embedding output:", [text_part(text)]),
        ("Image embedding output:", [image_part(), text_part("")]),
        ("Image+Text embedding output:", [image_part(), text_part(text)]),
    )
    for heading, user_content in scenarios:
        print(heading)
        response = embed(user_content)
        print_embeddings(response.data[0].embedding)
|
||||
|
||||
|
||||
def run_siglip(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
@@ -213,7 +334,8 @@ def run_vlm2vec(client: OpenAI, model: str):
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
print("Image embedding output:")
|
||||
print_embeddings(response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
@@ -233,7 +355,8 @@ def run_vlm2vec(client: OpenAI, model: str):
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image+Text embedding output:", response.data[0].embedding)
|
||||
print("Image+Text embedding output:")
|
||||
print_embeddings(response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
@@ -249,11 +372,13 @@ def run_vlm2vec(client: OpenAI, model: str):
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
print("Text embedding output:")
|
||||
print_embeddings(response.data[0].embedding)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"clip": run_clip,
|
||||
"qwen3_vl": run_qwen3_vl,
|
||||
"dse_qwen2_vl": run_dse_qwen2_vl,
|
||||
"siglip": run_siglip,
|
||||
"vlm2vec": run_vlm2vec,
|
||||
@@ -1,60 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Example online usage of Score API.
|
||||
|
||||
Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST ``prompt`` as a JSON body to ``api_url`` and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
|
||||
|
||||
|
||||
def parse_args():
    """Parse the server host, port and model name from the command line."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) for every supported option.
    options = (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "jinaai/jina-reranker-m0"),
    )
    for flag, kind, fallback in options:
        parser.add_argument(flag, type=kind, default=fallback)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """Send one text-vs-images request to the ``/score`` endpoint and print the exchange."""
    api_url = f"http://{args.host}:{args.port}/score"

    image_urls = (
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
    )
    prompt = {
        "model": args.model,
        "text_1": "slm markdown",
        "text_2": {
            "content": [
                {"type": "image_url", "image_url": {"url": url}} for url in image_urls
            ]
        },
    }

    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 is string and text_2 is a image list:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
80
examples/pooling/score/vision_rerank_api_online.py
Normal file
80
examples/pooling/score/vision_rerank_api_online.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
"""
|
||||
Example Python client for multimodal rerank API which is compatible with
|
||||
Jina and Cohere https://jina.ai/reranker
|
||||
|
||||
Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
|
||||
e.g.
|
||||
vllm serve jinaai/jina-reranker-m0 --runner pooling
|
||||
|
||||
vllm serve Qwen/Qwen3-VL-Reranker-2B \
|
||||
--runner pooling \
|
||||
--max-model-len 4096 \
|
||||
--hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
|
||||
--chat-template examples/pooling/score/template/qwen3_vl_reranker.jinja
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
query = "A woman playing with her dog on a beach at sunset."
|
||||
documents = {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
"A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, " # noqa: E501
|
||||
"as the dog offers its paw in a heartwarming display of companionship and trust." # noqa: E501
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
    """Parse the ``--host`` and ``--port`` of the running server from the command line."""
    parser = argparse.ArgumentParser()
    for flag, kind, fallback in (("--host", str, "localhost"), ("--port", int, 8000)):
        parser.add_argument(flag, type=kind, default=fallback)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """Look up the served model, send one rerank request and print the outcome."""
    server = f"http://{args.host}:{args.port}"

    # Use the first model the server lists.
    listing = requests.get(f"{server}/v1/models", headers=headers)
    model = listing.json()["data"][0]["id"]

    payload = {"model": model, "query": query, "documents": documents}
    response = requests.post(f"{server}/rerank", headers=headers, json=payload)

    # Report a failure and stop; otherwise pretty-print the JSON body.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
71
examples/pooling/score/vision_score_api_online.py
Normal file
71
examples/pooling/score/vision_score_api_online.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
"""
|
||||
Example online usage of Score API.
|
||||
|
||||
Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
|
||||
e.g.
|
||||
vllm serve jinaai/jina-reranker-m0 --runner pooling
|
||||
|
||||
vllm serve Qwen/Qwen3-VL-Reranker-2B \
|
||||
--runner pooling \
|
||||
--max-model-len 4096 \
|
||||
--hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
|
||||
--chat-template examples/pooling/score/template/qwen3_vl_reranker.jinja
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
text_1 = "slm markdown"
|
||||
text_2 = {
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
    """Read the server host and port from the command line."""
    defaults = {"--host": (str, "localhost"), "--port": (int, 8000)}
    parser = argparse.ArgumentParser()
    for flag, (kind, fallback) in defaults.items():
        parser.add_argument(flag, type=kind, default=fallback)
    namespace = parser.parse_args()
    return namespace
|
||||
|
||||
|
||||
def main(args):
    """Look up the served model, post one ``/score`` request and print prompt and reply."""
    server = f"http://{args.host}:{args.port}"

    # Use the first model the server lists.
    served = requests.get(f"{server}/v1/models", headers=headers).json()
    prompt = {
        "model": served["data"][0]["id"],
        "text_1": text_1,
        "text_2": text_2,
    }

    response = requests.post(f"{server}/score", headers=headers, json=prompt)
    print("\nPrompt when text_1 is string and text_2 is a image list:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    print(json.dumps(response.json(), indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user