[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
     stop_token_ids: list[int] | None = None
     chat_template: str | None = None
     lora_requests: list[LoRARequest] | None = None
+    sampling_params: SamplingParams | None = None


 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
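Since `ModelRequestData` is a `NamedTuple` and the new field defaults to `None`, existing loaders keep working unchanged; only loaders that need model-specific decoding settings set it. A minimal sketch, assuming the imports and `ModelRequestData` definition from this example script (the model name and prompt are placeholders):

```python
# Loaders that don't set sampling_params get the old behavior: run_chat
# (last hunk below) falls back to its generic 256-token config.
req = ModelRequestData(
    engine_args=EngineArgs(model="some/model"),  # placeholder model name
    prompt="example prompt",  # placeholder prompt
    image_data=[],
)
assert req.sampling_params is None
```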
@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
+    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
+
+    model_name = "deepseek-ai/DeepSeek-OCR"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        logits_processors=[NGramPerReqLogitsProcessor],
+    )
+
+    placeholder = "<image>\n" * len(image_urls)
+    prompt = placeholder + question
+
+    # The following sampling params config is taken from
+    # the official Deepseek-OCR inference example.
+    # (IMPORTANT) Use the custom logits processor and avoid skipping
+    # special tokens for this model for the optimal OCR performance.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=8192,
+        # ngram logit processor args
+        extra_args=dict(
+            ngram_size=30,
+            window_size=90,
+            # whitelist: <td>, </td>
+            whitelist_token_ids={128821, 128822},
+        ),
+        skip_special_tokens=False,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+        sampling_params=sampling_params,
+    )
+
+
 def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "google/gemma-3-4b-it"
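For reference, a minimal standalone sketch of what this loader wires up, outside the example script's `run_chat` harness. The image URLs and the question text are placeholders; the engine and sampling arguments are copied from the diff above:

```python
from dataclasses import asdict

from vllm import LLM, EngineArgs, SamplingParams
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
from vllm.multimodal.utils import fetch_image

# Placeholder inputs; any reachable image URLs work.
image_urls = ["https://example.com/page1.png", "https://example.com/page2.png"]
question = "Free OCR."  # placeholder question text

engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-OCR",
    max_num_seqs=2,
    limit_mm_per_prompt={"image": len(image_urls)},
    logits_processors=[NGramPerReqLogitsProcessor],
)
llm = LLM(**asdict(engine_args))

prompt = "<image>\n" * len(image_urls) + question
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=8192,
    extra_args=dict(
        ngram_size=30,
        window_size=90,
        whitelist_token_ids={128821, 128822},  # <td>, </td>
    ),
    skip_special_tokens=False,
)

# Pass all images for one prompt via multi_modal_data.
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(url) for url in image_urls]},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```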
@@ -1253,6 +1294,7 @@ model_example_map = {
     "bee": load_bee,
     "command_a_vision": load_command_a_vision,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "deepseek_ocr": load_deepseek_ocr,
     "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)

-    sampling_params = SamplingParams(
-        temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
-    )
+    sampling_params = (
+        SamplingParams(
+            temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+        )
+        if req_data.sampling_params is None
+        else req_data.sampling_params
+    )
     outputs = llm.chat(
         [
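The net effect in `run_chat` is that `load_deepseek_ocr`'s own `SamplingParams` win over the generic 256-token default. A short sketch of that selection in isolation, assuming `question` and `image_urls` as defined in the script:

```python
# The OCR loader supplies its own params, so the fallback branch is skipped.
req_data = model_example_map["deepseek_ocr"](question, image_urls)
assert req_data.sampling_params is not None
assert req_data.sampling_params.max_tokens == 8192  # not the default 256
```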