Files
vllm/examples/pooling/score/colqwen3_rerank_online.py

259 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using ColQwen3 late interaction model for reranking and scoring.
ColQwen3 is a multi-modal ColBERT-style model based on Qwen3-VL.
It produces per-token embeddings and uses MaxSim scoring for retrieval
and reranking. Supports both text and image inputs.
Start the server with:
vllm serve TomoroAI/tomoro-colqwen3-embed-4b --max-model-len 50000
Then run this script:
python colqwen3_rerank_online.py
"""
import base64
from io import BytesIO
import requests
from PIL import Image
# Model identifier sent in the "model" field of every request payload; it is
# the same model named in the `vllm serve` command in the module docstring.
MODEL = "TomoroAI/tomoro-colqwen3-embed-4b"
# Root URL of the server; endpoint paths (/rerank, /score) are appended to it.
BASE_URL = "http://127.0.0.1:8000"
# HTTP headers passed with each POST request made by the example functions.
headers = {"accept": "application/json", "Content-Type": "application/json"}
# ── Image helpers ──────────────────────────────────────────
def load_image(url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as an RGB PIL image.

    The download is attempted first with no extra headers and, if the server
    answers 403, once more with a browser-like ``User-Agent`` header.

    Args:
        url: Location of the image to download.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If an attempt returns an error status other
            than 403.
        RuntimeError: If every attempt was answered with 403.
    """
    header_attempts = [
        {},
        {"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"},
    ]
    for attempt_headers in header_attempts:
        reply = requests.get(url, headers=attempt_headers, timeout=15)
        if reply.status_code != 403:
            # Any non-403 error status is raised here rather than retried.
            reply.raise_for_status()
            raw_bytes = BytesIO(reply.content)
            return Image.open(raw_bytes).convert("RGB")
    raise RuntimeError(f"Could not fetch image from {url}")
def encode_image_base64(image: Image.Image) -> str:
    """Serialize an image as PNG and wrap it in a base64 ``data:`` URI.

    Args:
        image: Object whose ``save(buffer, format="PNG")`` writes PNG bytes.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    payload = base64.b64encode(png_buffer.getvalue()).decode()
    return f"data:image/png;base64,{payload}"
def make_image_content(image_url: str, text: str = "Describe the image.") -> dict:
    """Download an image and package it with a text prompt as one document.

    Args:
        image_url: URL of the image to download and embed as a data URI.
        text: Text part placed after the image part in the content list.

    Returns:
        A dict with a single ``"content"`` key holding an ``image_url`` part
        followed by a ``text`` part.
    """
    data_uri = encode_image_base64(load_image(image_url))
    image_part = {"type": "image_url", "image_url": {"url": data_uri}}
    text_part = {"type": "text", "text": text}
    return {"content": [image_part, text_part]}
# ── Sample image URLs ─────────────────────────────────────
# Label -> image URL for the sample documents used by the multi-modal examples.
IMAGE_URLS = dict(
    beijing="https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
    london="https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
    singapore="https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
)
# ── Text-only examples ────────────────────────────────────
def rerank_text(timeout: float = 120.0) -> None:
    """Rerank a fixed set of text documents against a text query via /rerank.

    Posts the query and documents to ``{BASE_URL}/rerank`` and prints each
    returned result with its relevance score, in the order the server
    returns them. On a non-200 response, prints the status code and the
    first 300 characters of the response body instead.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    print("=" * 60)
    print("1. Text reranking (/rerank)")
    print("=" * 60)
    data = {
        "model": MODEL,
        "query": "What is machine learning?",
        "documents": [
            "Machine learning is a subset of artificial intelligence.",
            "Python is a programming language.",
            "Deep learning uses neural networks for complex tasks.",
            "The weather today is sunny.",
        ],
    }
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the example.
    response = requests.post(
        f"{BASE_URL}/rerank", headers=headers, json=data, timeout=timeout
    )
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    result = response.json()
    print("\n Ranked documents (most relevant first):")
    for item in result["results"]:
        # "index" refers back to the position in the submitted document list.
        doc_idx = item["index"]
        score = item["relevance_score"]
        print(f" [{score:.4f}] {data['documents'][doc_idx]}")
def score_text(timeout: float = 120.0) -> None:
    """Score a text query against a fixed set of text documents via /score.

    Posts the query as ``text_1`` and the documents as ``text_2`` to
    ``{BASE_URL}/score`` and prints one line per returned item with its
    index, score and document text. On a non-200 response, prints the status
    code and the first 300 characters of the response body instead.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    print()
    print("=" * 60)
    print("2. Text scoring (/score)")
    print("=" * 60)
    query = "What is the capital of France?"
    documents = [
        "The capital of France is Paris.",
        "Berlin is the capital of Germany.",
        "Python is a programming language.",
    ]
    data = {
        "model": MODEL,
        "text_1": query,
        "text_2": documents,
    }
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the example.
    response = requests.post(
        f"{BASE_URL}/score", headers=headers, json=data, timeout=timeout
    )
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    result = response.json()
    print(f"\n Query: {query}\n")
    for item in result["data"]:
        # "index" refers back to the position in the submitted document list.
        idx = item["index"]
        score = item["score"]
        print(f" Doc {idx} (score={score:.4f}): {documents[idx]}")
def score_text_top_n(timeout: float = 120.0) -> None:
    """Rerank text documents via /rerank while requesting ``top_n=2``.

    Posts the query, documents and ``top_n`` to ``{BASE_URL}/rerank`` and
    prints every result the server returns with its relevance score. On a
    non-200 response, prints the status code and the first 300 characters of
    the response body instead.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    print()
    print("=" * 60)
    print("3. Text reranking with top_n=2 (/rerank)")
    print("=" * 60)
    data = {
        "model": MODEL,
        "query": "What is the capital of France?",
        "documents": [
            "The capital of France is Paris.",
            "Berlin is the capital of Germany.",
            "Python is a programming language.",
            "The Eiffel Tower is in Paris.",
        ],
        "top_n": 2,
    }
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the example.
    response = requests.post(
        f"{BASE_URL}/rerank", headers=headers, json=data, timeout=timeout
    )
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    result = response.json()
    print(f"\n Top {data['top_n']} results:")
    for item in result["results"]:
        # "index" refers back to the position in the submitted document list.
        doc_idx = item["index"]
        score = item["relevance_score"]
        print(f" [{score:.4f}] {data['documents'][doc_idx]}")
# ── Multi-modal examples (text query × image documents) ──
def score_text_vs_images(timeout: float = 120.0) -> None:
    """Score a text query against the sample images via /score.

    Downloads every image in ``IMAGE_URLS``, wraps each one with
    ``make_image_content``, posts the query as ``data_1`` and the image
    documents as ``data_2`` to ``{BASE_URL}/score``, and prints one line per
    returned item with its index, label and score. On a non-200 response,
    prints the status code and the first 300 characters of the response body
    instead.

    Args:
        timeout: Seconds to wait for the scoring request before giving up.
            The image downloads are not governed by this value.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    print()
    print("=" * 60)
    print("4. Multi-modal scoring: text query vs image docs (/score)")
    print("=" * 60)
    query = "Retrieve the city of Beijing"
    # Label order defines document order, so a result index maps to a label.
    labels = list(IMAGE_URLS)
    print(f"\n Loading {len(labels)} images...")
    image_contents = [make_image_content(IMAGE_URLS[name]) for name in labels]
    data = {
        "model": MODEL,
        "data_1": query,
        "data_2": image_contents,
    }
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the example.
    response = requests.post(
        f"{BASE_URL}/score", headers=headers, json=data, timeout=timeout
    )
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    result = response.json()
    print(f'\n Query: "{query}"\n')
    for item in result["data"]:
        idx = item["index"]
        print(f" Doc {idx} [{labels[idx]}] score={item['score']:.4f}")
def rerank_text_vs_images(timeout: float = 120.0) -> None:
    """Rerank the sample images by a text query via /rerank with ``top_n=2``.

    Downloads every image in ``IMAGE_URLS``, wraps each one with
    ``make_image_content``, posts the query, image documents and ``top_n``
    to ``{BASE_URL}/rerank``, and prints each returned result's relevance
    score and label. On a non-200 response, prints the status code and the
    first 300 characters of the response body instead.

    Args:
        timeout: Seconds to wait for the rerank request before giving up.
            The image downloads are not governed by this value.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    print()
    print("=" * 60)
    print("5. Multi-modal reranking: text query vs image docs (/rerank)")
    print("=" * 60)
    query = "Retrieve the city of London"
    # Label order defines document order, so a result index maps to a label.
    labels = list(IMAGE_URLS)
    print(f"\n Loading {len(labels)} images...")
    image_contents = [make_image_content(IMAGE_URLS[name]) for name in labels]
    data = {
        "model": MODEL,
        "query": query,
        "documents": image_contents,
        "top_n": 2,
    }
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the example.
    response = requests.post(
        f"{BASE_URL}/rerank", headers=headers, json=data, timeout=timeout
    )
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    result = response.json()
    print(f'\n Query: "{query}"')
    print(f" Top {data['top_n']} results:\n")
    for item in result["results"]:
        idx = item["index"]
        print(f" [{item['relevance_score']:.4f}] {labels[idx]}")
# ── Main ──────────────────────────────────────────────────
def main():
    """Run every example in sequence: text-only first, then multi-modal."""
    # Text-only
    text_examples = (rerank_text, score_text, score_text_top_n)
    # Multi-modal (text query × image documents)
    image_examples = (score_text_vs_images, rerank_text_vs_images)
    for run_example in (*text_examples, *image_examples):
        run_example()


if __name__ == "__main__":
    main()