[Bugfix] Use dedicated MM processor cache in /tokenize to prevent sender-cache pollution (#38545)

Signed-off-by: Sergey Zinchenko <sergey.zinchenko.rnd@gmail.com>
This commit is contained in:
Sergey Zinchenko
2026-04-02 07:14:49 +03:00
committed by GitHub
parent 5f96f9aff1
commit 5a2d420c17
4 changed files with 192 additions and 24 deletions

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Regression test: calling ``/tokenize`` with multimodal data followed by
``/v1/chat/completions`` with the same data must not cause an error.
Ensures that the ``/tokenize`` endpoint does not pollute internal caches
(e.g. multimodal feature caches) and that a subsequent
``/v1/chat/completions`` request with the same multimodal payload
completes successfully.
"""
import json
import openai
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--max-num-seqs",
"5",
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"image": 1}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_tokenize_then_chat_completion_with_image(
client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
local_asset_server,
):
"""Tokenize a multimodal message, then send the same message to chat
completions. The chat completion must succeed (not 500)."""
image_url = local_asset_server.url_for("stop_sign.jpg")
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Describe this image briefly."},
],
}
]
tok_resp = requests.post(
server.url_for("tokenize"),
json={"model": MODEL_NAME, "messages": messages},
)
tok_resp.raise_for_status()
tok_data = tok_resp.json()
assert tok_data["count"] > 0, "Tokenization must return tokens"
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
temperature=0.0,
)
assert chat_completion.choices[0].message.content, (
"Chat completion must produce non-empty content after tokenize"
)

View File

@@ -451,6 +451,8 @@ class OpenAIServingRender:
request: Any,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
*,
skip_mm_cache: bool = False,
) -> list[EngineInput]:
"""Copied from OpenAIServing._preprocess_completion."""
prompts = list[SingletonPrompt | bytes]()
@@ -458,12 +460,14 @@ class OpenAIServingRender:
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
return await self.preprocess_cmpl(request, prompts)
return await self.preprocess_cmpl(request, prompts, skip_mm_cache=skip_mm_cache)
async def preprocess_cmpl(
self,
request: Any,
prompts: Sequence[PromptType | bytes],
*,
skip_mm_cache: bool = False,
) -> list[EngineInput]:
"""Copied from OpenAIServing._preprocess_cmpl."""
renderer = self.renderer
@@ -487,6 +491,7 @@ class OpenAIServingRender:
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
skip_mm_cache=skip_mm_cache,
)
async def preprocess_chat(
@@ -498,6 +503,8 @@ class OpenAIServingRender:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None,
*,
skip_mm_cache: bool = False,
) -> tuple[list[ConversationMessage], list[EngineInput]]:
"""Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer
@@ -529,6 +536,7 @@ class OpenAIServingRender:
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
skip_mm_cache=skip_mm_cache,
)
# tool parsing is done only if a tool_parser has been set and if

View File

@@ -86,12 +86,14 @@ class OpenAIServingTokenization(OpenAIServing):
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
skip_mm_cache=True,
)
else:
engine_inputs = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=None,
skip_mm_cache=True,
)
input_ids: list[int] = []

View File

@@ -97,6 +97,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
self.mm_processor: BaseMultiModalProcessor | None = None
self._readonly_mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None
self._clear_mm_cache_async = make_async(
self.clear_mm_cache, executor=self._executor
@@ -124,6 +125,19 @@ class BaseRenderer(ABC, Generic[_T]):
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
# A second processor with its own processor-only cache.
# Used by the tokenize endpoint so that tokenize-only
# requests don't pollute the sender cache.
ro_cache = mm_registry.processor_only_cache_from_config(config)
if ro_cache is not None:
ro_tokenizer = copy.deepcopy(tokenizer)
with set_default_torch_num_threads():
self._readonly_mm_processor = mm_registry.create_processor(
config.model_config,
tokenizer=ro_tokenizer,
cache=ro_cache,
)
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
@@ -625,10 +639,15 @@ class BaseRenderer(ABC, Generic[_T]):
mm_uuids: MultiModalUUIDDict | None,
mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None,
*,
skip_mm_cache: bool = False,
) -> "MultiModalInput":
mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor()
if skip_mm_cache and self._readonly_mm_processor is not None:
mm_processor = self._readonly_mm_processor
else:
mm_processor = self.get_mm_processor()
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
@@ -656,6 +675,8 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_tokens(
self,
prompt: TokensPrompt,
*,
skip_mm_cache: bool = False,
) -> TokensInput | MultiModalInput:
"""Process token inputs, with multimodal preprocessing offloaded
to the shared thread pool in the async variant.
@@ -670,6 +691,7 @@ class BaseRenderer(ABC, Generic[_T]):
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, # Tokenization already done in Step 2
mm_uuids=prompt.get("multi_modal_uuids"),
skip_mm_cache=skip_mm_cache,
)
else:
engine_input = tokens_input(prompt_token_ids)
@@ -712,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
async def _process_tokens_async(
self,
prompt: TokensPrompt,
*,
skip_mm_cache: bool = False,
) -> TokensInput | MultiModalInput:
prompt_token_ids = prompt["prompt_token_ids"]
@@ -723,6 +747,7 @@ class BaseRenderer(ABC, Generic[_T]):
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None,
mm_uuids=prompt.get("multi_modal_uuids"),
skip_mm_cache=skip_mm_cache,
)
else:
engine_input = tokens_input(prompt_token_ids)
@@ -734,24 +759,33 @@ class BaseRenderer(ABC, Generic[_T]):
return engine_input
def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type]
async def _process_singleton_async(
def _process_singleton(
self,
prompt: SingletonTokPrompt,
*,
skip_mm_cache: bool = False,
) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return await self._process_tokens_async(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
async def _process_singleton_async(
self,
prompt: SingletonTokPrompt,
*,
skip_mm_cache: bool = False,
) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return await self._process_tokens_async(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
*,
skip_mm_cache: bool = False,
) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
@@ -764,9 +798,13 @@ class BaseRenderer(ABC, Generic[_T]):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_input(
encoder_input=self._process_singleton(enc_prompt),
encoder_input=self._process_singleton(
enc_prompt, skip_mm_cache=skip_mm_cache
),
decoder_input=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
None
if dec_prompt is None
else self._process_singleton(dec_prompt, skip_mm_cache=skip_mm_cache)
),
decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
@@ -775,16 +813,20 @@ class BaseRenderer(ABC, Generic[_T]):
async def _process_enc_dec_async(
self,
prompt: EncoderDecoderTokPrompt,
*,
skip_mm_cache: bool = False,
) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
encoder_input, decoder_input = await asyncio.gather(
self._process_singleton_async(enc_prompt),
self._process_singleton_async(enc_prompt, skip_mm_cache=skip_mm_cache),
(
asyncio.sleep(0)
if dec_prompt is None
else self._process_singleton_async(dec_prompt)
else self._process_singleton_async(
dec_prompt, skip_mm_cache=skip_mm_cache
)
),
)
@@ -794,27 +836,40 @@ class BaseRenderer(ABC, Generic[_T]):
decoder_start_token_id=self.get_dec_start_token_id(),
)
def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
def process_for_engine(
self,
prompt: TokPrompt,
arrival_time: float,
*,
skip_mm_cache: bool = False,
) -> EngineInput:
engine_input: EngineInput
if "encoder_prompt" in prompt:
engine_input = self._process_enc_dec(prompt) # type: ignore[arg-type]
engine_input = self._process_enc_dec(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
else:
engine_input = self._process_singleton(prompt)
engine_input = self._process_singleton(prompt, skip_mm_cache=skip_mm_cache)
engine_input["arrival_time"] = arrival_time
return engine_input
async def process_for_engine_async(
self, prompt: TokPrompt, arrival_time: float
self,
prompt: TokPrompt,
arrival_time: float,
*,
skip_mm_cache: bool = False,
) -> EngineInput:
engine_input: EngineInput
if "encoder_prompt" in prompt:
engine_input = await self._process_enc_dec_async(
prompt # type: ignore[arg-type]
prompt, # type: ignore[arg-type]
skip_mm_cache=skip_mm_cache,
)
else:
engine_input = await self._process_singleton_async(prompt)
engine_input = await self._process_singleton_async(
prompt, skip_mm_cache=skip_mm_cache
)
engine_input["arrival_time"] = arrival_time
@@ -827,6 +882,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
):
arrival_time = time.time()
@@ -838,7 +894,10 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
return [
self.process_for_engine(prompt, arrival_time, skip_mm_cache=skip_mm_cache)
for prompt in tok_prompts
]
async def render_cmpl_async(
self,
@@ -846,6 +905,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
):
arrival_time = time.time()
@@ -858,7 +918,12 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
return await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
*(
self.process_for_engine_async(
p, arrival_time, skip_mm_cache=skip_mm_cache
)
for p in tok_prompts
)
)
def render_chat(
@@ -868,6 +933,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
):
arrival_time = time.time()
@@ -890,7 +956,8 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
self.process_for_engine(prompt, arrival_time, skip_mm_cache=skip_mm_cache)
for prompt in tok_prompts
]
return out_conversations, eng_prompts
@@ -902,6 +969,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
):
arrival_time = time.time()
@@ -924,7 +992,12 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
*(
self.process_for_engine_async(
p, arrival_time, skip_mm_cache=skip_mm_cache
)
for p in tok_prompts
)
)
return out_conversations, eng_prompts