[Frontend] Support multimodal inputs for late-interaction scoring (ColQwen3) + NewModel: nvidia/nemotron-colembed (#34574)

Signed-off-by: craftsangjae <craftsangjae@gmail.com>
This commit is contained in:
Kata Coder
2026-02-21 13:01:40 +09:00
committed by GitHub
parent 11be2c74dc
commit 5719a4e4e6
10 changed files with 532 additions and 66 deletions


@@ -382,6 +382,7 @@ ColQwen3 is based on [ColPali](https://arxiv.org/abs/2407.01449), which extends
|---|---|---|
| `ColQwen3` | Qwen3-VL | `TomoroAI/tomoro-colqwen3-embed-4b`, `TomoroAI/tomoro-colqwen3-embed-8b` |
| `OpsColQwen3Model` | Qwen3-VL | `OpenSearch-AI/Ops-Colqwen3-4B`, `OpenSearch-AI/Ops-Colqwen3-8B` |
| `Qwen3VLNemotronEmbedModel` | Qwen3-VL | `nvidia/nemotron-colembed-vl-4b-v2`, `nvidia/nemotron-colembed-vl-8b-v2` |
Start the server:
@@ -389,7 +390,9 @@ Start the server:
vllm serve TomoroAI/tomoro-colqwen3-embed-4b --max-model-len 4096
```
#### Text-only scoring and reranking
Use the `/rerank` endpoint:
```shell
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
@@ -403,7 +406,7 @@ curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
}'
```
Or the `/score` endpoint:
```shell
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
@@ -413,7 +416,57 @@ curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
}'
```
#### Multi-modal scoring and reranking (text query × image documents)
The `/score` and `/rerank` endpoints also accept multi-modal inputs directly.
Pass image documents using the `data_1`/`data_2` (for `/score`) or `documents` (for `/rerank`) fields
with a `content` list containing `image_url` and `text` parts — the same format used by the
OpenAI chat completion API:
Score a text query against image documents:
```shell
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
"model": "TomoroAI/tomoro-colqwen3-embed-4b",
"data_1": "Retrieve the city of Beijing",
"data_2": [
{
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64>"}},
{"type": "text", "text": "Describe the image."}
]
}
]
}'
```
Rerank image documents by a text query:
```shell
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
"model": "TomoroAI/tomoro-colqwen3-embed-4b",
"query": "Retrieve the city of Beijing",
"documents": [
{
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_1>"}},
{"type": "text", "text": "Describe the image."}
]
},
{
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_2>"}},
{"type": "text", "text": "Describe the image."}
]
}
],
"top_n": 2
}'
```
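The `<BASE64>` placeholders above stand for a base64-encoded image. One way to build the required data URI from a local PNG file — a stdlib-only sketch, where the helper name is illustrative:

```python
import base64
from pathlib import Path

def png_to_data_uri(path: str) -> str:
    """Read a PNG file and return a data URI for the image_url.url field."""
    b64 = base64.b64encode(Path(path).read_bytes()).decode()
    return f"data:image/png;base64,{b64}"
```

The resulting string can be substituted directly for `data:image/png;base64,<BASE64>` in the curl payloads above.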
#### Raw token embeddings
You can also get the raw token embeddings using the `/pooling` endpoint with `token_embed` task:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
@@ -423,7 +476,7 @@ curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
}'
```
For **image inputs** via the pooling endpoint, use the chat-style `messages` field:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
@@ -440,10 +493,10 @@ curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
}'
```
#### Examples
- Multi-vector retrieval: [examples/pooling/token_embed/colqwen3_token_embed_online.py](../../examples/pooling/token_embed/colqwen3_token_embed_online.py)
- Reranking (text + multi-modal): [examples/pooling/score/colqwen3_rerank_online.py](../../examples/pooling/score/colqwen3_rerank_online.py)
### BAAI/bge-m3


@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using ColQwen3 late interaction model for reranking and scoring.
ColQwen3 is a multi-modal ColBERT-style model based on Qwen3-VL.
It produces per-token embeddings and uses MaxSim scoring for retrieval
@@ -14,13 +15,65 @@ Then run this script:
python colqwen3_rerank_online.py
"""
import base64
from io import BytesIO
import requests
from PIL import Image
MODEL = "TomoroAI/tomoro-colqwen3-embed-4b"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"}
# ── Image helpers ──────────────────────────────────────────
def load_image(url: str) -> Image.Image:
"""Download an image from URL (handles Wikimedia 403)."""
for hdrs in (
{},
{"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"},
):
resp = requests.get(url, headers=hdrs, timeout=15)
if resp.status_code == 403:
continue
resp.raise_for_status()
return Image.open(BytesIO(resp.content)).convert("RGB")
raise RuntimeError(f"Could not fetch image from {url}")
def encode_image_base64(image: Image.Image) -> str:
"""Encode a PIL image to a base64 data URI."""
buf = BytesIO()
image.save(buf, format="PNG")
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
def make_image_content(image_url: str, text: str = "Describe the image.") -> dict:
"""Build a ScoreMultiModalParam dict from an image URL."""
image = load_image(image_url)
return {
"content": [
{
"type": "image_url",
"image_url": {"url": encode_image_base64(image)},
},
{"type": "text", "text": text},
]
}
# ── Sample image URLs ─────────────────────────────────────
IMAGE_URLS = {
"beijing": "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
"london": "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
"singapore": "https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
}
# ── Text-only examples ────────────────────────────────────
def rerank_text():
"""Text-only reranking via /rerank endpoint."""
@@ -120,11 +173,86 @@ def score_text_top_n():
print(f" {response.text[:300]}")
# ── Multi-modal examples (text query × image documents) ──
def score_text_vs_images():
"""Score a text query against image documents via /score."""
print()
print("=" * 60)
print("4. Multi-modal scoring: text query vs image docs (/score)")
print("=" * 60)
query = "Retrieve the city of Beijing"
labels = list(IMAGE_URLS.keys())
print(f"\n Loading {len(labels)} images...")
image_contents = [make_image_content(IMAGE_URLS[name]) for name in labels]
data = {
"model": MODEL,
"data_1": query,
"data_2": image_contents,
}
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f'\n Query: "{query}"\n')
for item in result["data"]:
idx = item["index"]
print(f" Doc {idx} [{labels[idx]}] score={item['score']:.4f}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def rerank_text_vs_images():
"""Rerank image documents by a text query via /rerank."""
print()
print("=" * 60)
print("5. Multi-modal reranking: text query vs image docs (/rerank)")
print("=" * 60)
query = "Retrieve the city of London"
labels = list(IMAGE_URLS.keys())
print(f"\n Loading {len(labels)} images...")
image_contents = [make_image_content(IMAGE_URLS[name]) for name in labels]
data = {
"model": MODEL,
"query": query,
"documents": image_contents,
"top_n": 2,
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f'\n Query: "{query}"')
print(f" Top {data['top_n']} results:\n")
for item in result["results"]:
idx = item["index"]
print(f" [{item['relevance_score']:.4f}] {labels[idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
# ── Main ──────────────────────────────────────────────────
def main():
# Text-only
rerank_text()
score_text()
score_text_top_n()
# Multi-modal (text query × image documents)
score_text_vs_images()
rerank_text_vs_images()
if __name__ == "__main__":
main()


@@ -7,19 +7,31 @@ ColBERT-style late interaction scoring (MaxSim). It produces per-token
embeddings for both text and image inputs.
"""
import base64
from io import BytesIO
import pytest
import torch
from PIL import Image
from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
)
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
from ....conftest import VllmRunner
MODELS = [
"TomoroAI/tomoro-colqwen3-embed-4b",
"OpenSearch-AI/Ops-Colqwen3-4B",
"nvidia/nemotron-colembed-vl-4b-v2",
]
EMBED_DIMS = {
"TomoroAI/tomoro-colqwen3-embed-4b": 320,
"OpenSearch-AI/Ops-Colqwen3-4B": 2560,
"nvidia/nemotron-colembed-vl-4b-v2": 2560,
}
TEXT_QUERIES = [
@@ -33,6 +45,43 @@ TEXT_DOCUMENTS = [
]
DTYPE = "half"
GPU_MEMORY_UTILIZATION = 0.7
def _make_base64_image(
width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
"""Create a small solid-color PNG image and return its base64 data URI."""
img = Image.new("RGB", (width, height), color)
buf = BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
def _make_image_mm_param(
image_uri: str,
text: str | None = None,
) -> ScoreMultiModalParam:
"""Build a ScoreMultiModalParam containing an image (and optional text)."""
content: list = [
ChatCompletionContentPartImageParam(
type="image_url",
image_url={"url": image_uri},
),
]
if text is not None:
content.append(
ChatCompletionContentPartTextParam(type="text", text=text),
)
return ScoreMultiModalParam(content=content)
def _make_text_mm_param(text: str) -> ScoreMultiModalParam:
"""Build a ScoreMultiModalParam containing only text."""
return ScoreMultiModalParam(
content=[ChatCompletionContentPartTextParam(type="text", text=text)],
)
def _run_token_embed_test(
@@ -48,6 +97,7 @@ def _run_token_embed_test(
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
@@ -83,6 +133,7 @@ def _run_late_interaction_test(
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
q_outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
d_outputs = vllm_model.token_embed([TEXT_DOCUMENTS[0]])
@@ -118,6 +169,7 @@ def _run_relevance_test(
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.score(query, documents)
@@ -154,3 +206,142 @@ def test_colqwen3_relevance_ordering(
dtype: str,
) -> None:
_run_relevance_test(vllm_runner, model, dtype=dtype)
# ── Multimodal scoring tests ────────────────────────────────
def _run_multimodal_text_query_image_docs_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Score a text query against image documents via the multimodal path.
Verifies that score_data_to_prompts correctly handles image content
and produces valid MaxSim scores.
"""
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
blue_image = _make_base64_image(64, 64, color=(0, 0, 255))
query = "Describe the red object"
image_docs = [
_make_image_mm_param(red_image),
_make_image_mm_param(blue_image),
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.llm.score(query, image_docs)
assert len(scores) == 2
for s in scores:
assert isinstance(s.outputs.score, float)
def _run_multimodal_mixed_docs_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Score a text query against a mix of text and image documents.
Ensures the late-interaction path handles heterogeneous document
types (plain strings alongside ScoreMultiModalParam images) in
a single call.
"""
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
query = "What is the capital of France?"
documents: list = [
"The capital of France is Paris.",
_make_image_mm_param(red_image),
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.llm.score(query, documents)
assert len(scores) == 2
for s in scores:
assert isinstance(s.outputs.score, float)
# Text document about France should score higher than a random image
assert scores[0].outputs.score > scores[1].outputs.score
def _run_multimodal_image_query_text_docs_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Score an image query against text documents.
Verifies the reverse direction: multimodal query with text-only
documents through the late-interaction scoring path.
"""
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
image_query = _make_image_mm_param(red_image, text="red color")
documents = [
"A bright red sports car.",
"The weather forecast shows rain tomorrow.",
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.llm.score(image_query, documents)
assert len(scores) == 2
for s in scores:
assert isinstance(s.outputs.score, float)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colqwen3_multimodal_text_query_image_docs(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_multimodal_text_query_image_docs_test(vllm_runner, model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colqwen3_multimodal_mixed_docs(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_multimodal_mixed_docs_test(vllm_runner, model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colqwen3_multimodal_image_query_text_docs(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_multimodal_image_query_text_docs_test(vllm_runner, model, dtype=dtype)


@@ -603,6 +603,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"OpsColQwen3Model": _HfExamplesInfo(
"OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True
),
"Qwen3VLNemotronEmbedModel": _HfExamplesInfo(
"nvidia/nemotron-colembed-vl-4b-v2",
),
"SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"),
"PrithviGeoSpatialMAE": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",


@@ -50,6 +50,7 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids,
compute_maxsim_score,
get_score_prompt,
score_data_to_prompts,
validate_score_input,
)
from vllm.entrypoints.utils import log_non_default_args
@@ -1395,25 +1396,13 @@ class LLM:
tokenizer = self.get_tokenizer()
# Convert ScoreData to PromptType (handles both text and multimodal)
model_config = self.model_config
prompts_1 = score_data_to_prompts(data_1, "query", model_config)
prompts_2 = score_data_to_prompts(data_2, "document", model_config)
encoded_output: list[PoolingRequestOutput] = self.encode(
prompts_1 + prompts_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
@@ -1421,8 +1410,8 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)
encoded_output_1: list[PoolingRequestOutput] = encoded_output[: len(prompts_1)]
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(prompts_1) :]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)


@@ -33,6 +33,7 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids,
compute_maxsim_score,
get_score_prompt,
parse_score_data_single,
validate_score_input,
)
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
@@ -174,6 +175,43 @@ class ServingScores(OpenAIServing):
return final_res_batch
def _preprocess_late_interaction_item(
self,
data: ScoreData,
role: str,
request: RerankRequest | ScoreRequest,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
) -> tuple[str, TokensPrompt]:
"""Parse a single ScoreData into a text + optional multimodal
TokensPrompt for late-interaction encoding.
For plain strings, tokenises directly.
For multimodal content parts, extracts text and multi_modal_data.
"""
model_config = self.model_config
if isinstance(data, str):
text, mm_data, mm_uuids = data, None, None
else:
text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config)
prompt_inputs = tokenizer(text, **tokenization_kwargs)
self._validate_input(request, prompt_inputs["input_ids"], text)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
)
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
return text, engine_prompt
async def _late_interaction_score(
self,
data_1: list[ScoreData],
@@ -189,37 +227,36 @@ class ServingScores(OpenAIServing):
Encodes queries and documents into per-token embeddings, then computes
MaxSim: sum over query tokens of max similarity to any document token.
"""
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
all_data = data_1 + data_2
roles = ["query"] * len(data_1) + ["document"] * len(data_2)
preprocess_async = make_async(
self._preprocess_late_interaction_item,
executor=self._tokenizer_executor,
)
preprocessed = await asyncio.gather(
*(
preprocess_async(
data=d,
role=r,
request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
for d, r in zip(all_data, roles)
)
)
input_texts: list[str] = []
engine_prompts: list[TokensPrompt] = []
for text, engine_prompt in preprocessed:
input_texts.append(text)
engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []


@@ -21,6 +21,7 @@ from vllm.entrypoints.chat_utils import (
_parse_chat_message_content_parts,
)
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
@@ -153,31 +154,91 @@ def validate_score_input(
return score_input_1, score_input_2
def _ensure_str(content: list[ConversationMessage]) -> str:
"""Extract a single string prompt from parsed conversation content."""
assert len(content) == 1
prompt = content[0]["content"]
if prompt is not None and isinstance(prompt, str):
return cast(str, prompt)
raise ValueError(f"Only string content is supported, but got {content}.")
def parse_score_data(
data_1: ScoreData,
data_2: ScoreData,
model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
"""Parse a query-document pair into text prompts and shared multi-modal
data.
Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
items from both inputs are merged into one ``mm_data`` dict. This is
the correct behaviour for cross-encoder scoring, where query and
document are concatenated into a single model prompt.
"""
mm_tracker = MultiModalItemTracker(model_config)
content_1 = _parse_score_content("query", data_1, mm_tracker)
content_2 = _parse_score_content("document", data_2, mm_tracker)
prompt_1 = _ensure_str(content_1)
prompt_2 = _ensure_str(content_2)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt_1, prompt_2, mm_items, mm_uuids
def parse_score_data_single(
data: ScoreData,
role: str,
model_config: ModelConfig,
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
"""Parse **one** ScoreData into a text prompt and its own multi-modal
data.
Unlike :func:`parse_score_data`, each call creates an **independent**
:class:`MultiModalItemTracker` so multi-modal items are kept separate.
This is the correct behaviour for late-interaction scoring, where
query and document are encoded independently.
"""
mm_tracker = MultiModalItemTracker(model_config)
content = _parse_score_content(role, data, mm_tracker)
prompt = _ensure_str(content)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt, mm_items, mm_uuids
def score_data_to_prompts(
data_list: list[ScoreData],
role: str,
model_config: ModelConfig,
) -> list[PromptType]:
"""Convert a list of ScoreData into PromptType objects.
For plain text inputs, returns the string directly.
For multimodal inputs (list of content parts), parses them into
a :class:`TextPrompt` with attached ``multi_modal_data`` /
``multi_modal_uuids``.
This is used by late-interaction scoring where each query/document
is encoded independently.
"""
prompts: list[PromptType] = []
for data in data_list:
if isinstance(data, str):
prompts.append(data)
else:
text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config)
prompt: TextPrompt = TextPrompt(prompt=text)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
prompts.append(prompt)
return prompts
def _parse_score_content(
role: str,
data: ScoreData,


@@ -16,6 +16,7 @@ Based on: Qwen3-VL backbone with custom text projection
Target models: Target models:
- TomoroAI/tomoro-colqwen3-embed-8b
- OpenSearch-AI/Ops-Colqwen3-4B
- nvidia/nemotron-colembed-vl-4b-v2
"""
from collections.abc import Iterable, Mapping
@@ -229,13 +230,14 @@ class ColQwen3Model(
if not isinstance(hidden_states, torch.Tensor):
return hidden_states  # type: ignore
if self.custom_text_proj is not None:
proj_dtype = self.custom_text_proj.weight.dtype
if hidden_states.dtype != proj_dtype:
hidden_states = hidden_states.to(proj_dtype)
hidden_states = self.custom_text_proj(hidden_states)
# L2 normalize
return torch.nn.functional.normalize(hidden_states, p=2, dim=-1)
# Names used for the projection layer across different ColQwen3 variants # Names used for the projection layer across different ColQwen3 variants
_PROJ_LAYER_NAMES = { _PROJ_LAYER_NAMES = {


@@ -256,6 +256,7 @@ _EMBEDDING_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
"ColQwen3": ("colqwen3", "ColQwen3Model"),
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
"SiglipModel": ("siglip", "SiglipEmbeddingModel"),
# Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggy-backs on embedding


@@ -76,6 +76,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
chatglm="ChatGLMConfig",
colqwen3="ColQwen3Config",
ops_colqwen3="OpsColQwen3Config",
qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig",
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",