[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py <mozf@mail2.sysu.edu.cn>
Date: 2025-10-23 08:15:38 +08:00
Committed by: GitHub
Parent: b4fda58a2d
Commit: 2566dca2a9
4 changed files with 112 additions and 66 deletions

examples/offline_inference/vision_language_multi_image.py (shown below; the remaining changed files are model-side)
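Only the example-script portion of the change appears below; the model-side portion (enabling `merge_by_field_config=True` and declaring tensor schemas for DeepSeek-OCR's multimodal inputs) lives in the other changed files. As a rough, non-authoritative sketch of what that pattern typically looks like in vLLM model code (class and field names here are illustrative, not the commit's actual definitions):

# Illustrative sketch only -- not this commit's actual model code.
# Assumes vLLM's TensorSchema utilities; names are hypothetical.
from typing import Annotated, Literal

import torch
import torch.nn as nn

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class DeepseekOCRImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: batch size * number of images
        - h/w: height/width of each (3-channel) image
    """

    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


class DeepseekOCRForCausalLM(nn.Module):
    # Opt in to schema-driven merging of per-item multimodal kwargs,
    # so multi-image batches are stacked per the declared shapes above.
    merge_by_field_config = True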

@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
     stop_token_ids: list[int] | None = None
     chat_template: str | None = None
     lora_requests: list[LoRARequest] | None = None
+    sampling_params: SamplingParams | None = None


 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
+    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
+
+    model_name = "deepseek-ai/DeepSeek-OCR"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        logits_processors=[NGramPerReqLogitsProcessor],
+    )
+
+    placeholder = "<image>\n" * len(image_urls)
+    prompt = placeholder + question
+
+    # The following sampling params config is taken from
+    # the official DeepSeek-OCR inference example.
+    # (IMPORTANT) Use the custom logits processor and avoid skipping
+    # special tokens for this model for optimal OCR performance.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=8192,
+        # ngram logits processor args
+        extra_args=dict(
+            ngram_size=30,
+            window_size=90,
+            # whitelist: <td>, </td>
+            whitelist_token_ids={128821, 128822},
+        ),
+        skip_special_tokens=False,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+        sampling_params=sampling_params,
+    )
+
+
 def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "google/gemma-3-4b-it"
@@ -1253,6 +1294,7 @@ model_example_map = {
     "bee": load_bee,
     "command_a_vision": load_command_a_vision,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "deepseek_ocr": load_deepseek_ocr,
     "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)

-    sampling_params = SamplingParams(
-        temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+    sampling_params = (
+        SamplingParams(
+            temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+        )
+        if req_data.sampling_params is None
+        else req_data.sampling_params
     )
     outputs = llm.chat(
         [
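With this fallback in place, a loader-supplied `SamplingParams` takes priority over the 256-token default. A hypothetical driver (the URLs are placeholders; the real script dispatches through its CLI instead) would be:

# Hypothetical invocation of this script's run_chat() helper.
image_urls = [
    "https://example.com/ocr/page1.png",
    "https://example.com/ocr/page2.png",
]

# Picks up load_deepseek_ocr()'s SamplingParams (temperature=0.0,
# max_tokens=8192, ngram processor args) instead of the default.
run_chat("deepseek_ocr", "Free OCR.", image_urls, seed=0)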