[Bugfix] Use dedicated MM processor cache in /tokenize to prevent sender-cache pollution (#38545)
Signed-off-by: Sergey Zinchenko <sergey.zinchenko.rnd@gmail.com>
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Regression test: calling ``/tokenize`` with multimodal data followed by
|
||||
``/v1/chat/completions`` with the same data must not cause an error.
|
||||
|
||||
Ensures that the ``/tokenize`` endpoint does not pollute internal caches
|
||||
(e.g. multimodal feature caches) and that a subsequent
|
||||
``/v1/chat/completions`` request with the same multimodal payload
|
||||
completes successfully.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
"5",
|
||||
"--enforce-eager",
|
||||
"--limit-mm-per-prompt",
|
||||
json.dumps({"image": 1}),
|
||||
]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenize_then_chat_completion_with_image(
|
||||
client: openai.AsyncOpenAI,
|
||||
server: RemoteOpenAIServer,
|
||||
local_asset_server,
|
||||
):
|
||||
"""Tokenize a multimodal message, then send the same message to chat
|
||||
completions. The chat completion must succeed (not 500)."""
|
||||
|
||||
image_url = local_asset_server.url_for("stop_sign.jpg")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "Describe this image briefly."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
tok_resp = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={"model": MODEL_NAME, "messages": messages},
|
||||
)
|
||||
tok_resp.raise_for_status()
|
||||
tok_data = tok_resp.json()
|
||||
assert tok_data["count"] > 0, "Tokenization must return tokens"
|
||||
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert chat_completion.choices[0].message.content, (
|
||||
"Chat completion must produce non-empty content after tokenize"
|
||||
)
|
||||
@@ -451,6 +451,8 @@ class OpenAIServingRender:
|
||||
request: Any,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> list[EngineInput]:
|
||||
"""Copied from OpenAIServing._preprocess_completion."""
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
@@ -458,12 +460,14 @@ class OpenAIServingRender:
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
if prompt_input is not None:
|
||||
prompts.extend(prompt_to_seq(prompt_input))
|
||||
return await self.preprocess_cmpl(request, prompts)
|
||||
return await self.preprocess_cmpl(request, prompts, skip_mm_cache=skip_mm_cache)
|
||||
|
||||
async def preprocess_cmpl(
|
||||
self,
|
||||
request: Any,
|
||||
prompts: Sequence[PromptType | bytes],
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> list[EngineInput]:
|
||||
"""Copied from OpenAIServing._preprocess_cmpl."""
|
||||
renderer = self.renderer
|
||||
@@ -487,6 +491,7 @@ class OpenAIServingRender:
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
skip_mm_cache=skip_mm_cache,
|
||||
)
|
||||
|
||||
async def preprocess_chat(
|
||||
@@ -498,6 +503,8 @@ class OpenAIServingRender:
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: type[ToolParser] | None = None,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> tuple[list[ConversationMessage], list[EngineInput]]:
|
||||
"""Copied from OpenAIServing._preprocess_chat."""
|
||||
renderer = self.renderer
|
||||
@@ -529,6 +536,7 @@ class OpenAIServingRender:
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
skip_mm_cache=skip_mm_cache,
|
||||
)
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
|
||||
@@ -86,12 +86,14 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=self.default_chat_template_kwargs,
|
||||
tool_dicts=tool_dicts,
|
||||
skip_mm_cache=True,
|
||||
)
|
||||
else:
|
||||
engine_inputs = await self.openai_serving_render.preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.prompt,
|
||||
prompt_embeds=None,
|
||||
skip_mm_cache=True,
|
||||
)
|
||||
|
||||
input_ids: list[int] = []
|
||||
|
||||
@@ -97,6 +97,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
|
||||
|
||||
self.mm_processor: BaseMultiModalProcessor | None = None
|
||||
self._readonly_mm_processor: BaseMultiModalProcessor | None = None
|
||||
self._mm_cache_stats: MultiModalCacheStats | None = None
|
||||
self._clear_mm_cache_async = make_async(
|
||||
self.clear_mm_cache, executor=self._executor
|
||||
@@ -124,6 +125,19 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
if mm_processor_cache:
|
||||
self._mm_cache_stats = MultiModalCacheStats()
|
||||
|
||||
# A second processor with its own processor-only cache.
|
||||
# Used by the tokenize endpoint so that tokenize-only
|
||||
# requests don't pollute the sender cache.
|
||||
ro_cache = mm_registry.processor_only_cache_from_config(config)
|
||||
if ro_cache is not None:
|
||||
ro_tokenizer = copy.deepcopy(tokenizer)
|
||||
with set_default_torch_num_threads():
|
||||
self._readonly_mm_processor = mm_registry.create_processor(
|
||||
config.model_config,
|
||||
tokenizer=ro_tokenizer,
|
||||
cache=ro_cache,
|
||||
)
|
||||
|
||||
# This is used to generate internal request ID for MM processing
|
||||
# It has no relation to the request ID for engine core
|
||||
self._mm_req_counter = AtomicCounter()
|
||||
@@ -625,10 +639,15 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
mm_uuids: MultiModalUUIDDict | None,
|
||||
mm_processor_kwargs: Mapping[str, object] | None,
|
||||
tokenization_kwargs: dict[str, Any] | None,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> "MultiModalInput":
|
||||
mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
|
||||
|
||||
mm_processor = self.get_mm_processor()
|
||||
if skip_mm_cache and self._readonly_mm_processor is not None:
|
||||
mm_processor = self._readonly_mm_processor
|
||||
else:
|
||||
mm_processor = self.get_mm_processor()
|
||||
|
||||
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
mm_uuid_items = parse_mm_uuids(mm_uuids)
|
||||
@@ -656,6 +675,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
def _process_tokens(
|
||||
self,
|
||||
prompt: TokensPrompt,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> TokensInput | MultiModalInput:
|
||||
"""Process token inputs, with multimodal preprocessing offloaded
|
||||
to the shared thread pool in the async variant.
|
||||
@@ -670,6 +691,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=None, # Tokenization already done in Step 2
|
||||
mm_uuids=prompt.get("multi_modal_uuids"),
|
||||
skip_mm_cache=skip_mm_cache,
|
||||
)
|
||||
else:
|
||||
engine_input = tokens_input(prompt_token_ids)
|
||||
@@ -712,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
async def _process_tokens_async(
|
||||
self,
|
||||
prompt: TokensPrompt,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> TokensInput | MultiModalInput:
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
|
||||
@@ -723,6 +747,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=None,
|
||||
mm_uuids=prompt.get("multi_modal_uuids"),
|
||||
skip_mm_cache=skip_mm_cache,
|
||||
)
|
||||
else:
|
||||
engine_input = tokens_input(prompt_token_ids)
|
||||
@@ -734,24 +759,33 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
return engine_input
|
||||
|
||||
def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return self._process_tokens(prompt) # type: ignore[arg-type]
|
||||
|
||||
async def _process_singleton_async(
|
||||
def _process_singleton(
|
||||
self,
|
||||
prompt: SingletonTokPrompt,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> SingletonInput:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return await self._process_tokens_async(prompt) # type: ignore[arg-type]
|
||||
return self._process_tokens(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
|
||||
|
||||
async def _process_singleton_async(
|
||||
self,
|
||||
prompt: SingletonTokPrompt,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> SingletonInput:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return await self._process_tokens_async(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
|
||||
|
||||
def _process_enc_dec(
|
||||
self,
|
||||
prompt: EncoderDecoderTokPrompt,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> EncoderDecoderInput:
|
||||
enc_prompt = prompt["encoder_prompt"]
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
@@ -764,9 +798,13 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
|
||||
|
||||
return build_enc_dec_input(
|
||||
encoder_input=self._process_singleton(enc_prompt),
|
||||
encoder_input=self._process_singleton(
|
||||
enc_prompt, skip_mm_cache=skip_mm_cache
|
||||
),
|
||||
decoder_input=(
|
||||
None if dec_prompt is None else self._process_singleton(dec_prompt)
|
||||
None
|
||||
if dec_prompt is None
|
||||
else self._process_singleton(dec_prompt, skip_mm_cache=skip_mm_cache)
|
||||
),
|
||||
decoder_start_token_id=self.get_dec_start_token_id(),
|
||||
skip_decoder_start_token=skip_decoder_start_token,
|
||||
@@ -775,16 +813,20 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
async def _process_enc_dec_async(
|
||||
self,
|
||||
prompt: EncoderDecoderTokPrompt,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> EncoderDecoderInput:
|
||||
enc_prompt = prompt["encoder_prompt"]
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
encoder_input, decoder_input = await asyncio.gather(
|
||||
self._process_singleton_async(enc_prompt),
|
||||
self._process_singleton_async(enc_prompt, skip_mm_cache=skip_mm_cache),
|
||||
(
|
||||
asyncio.sleep(0)
|
||||
if dec_prompt is None
|
||||
else self._process_singleton_async(dec_prompt)
|
||||
else self._process_singleton_async(
|
||||
dec_prompt, skip_mm_cache=skip_mm_cache
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -794,27 +836,40 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
decoder_start_token_id=self.get_dec_start_token_id(),
|
||||
)
|
||||
|
||||
def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
|
||||
def process_for_engine(
|
||||
self,
|
||||
prompt: TokPrompt,
|
||||
arrival_time: float,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> EngineInput:
|
||||
engine_input: EngineInput
|
||||
if "encoder_prompt" in prompt:
|
||||
engine_input = self._process_enc_dec(prompt) # type: ignore[arg-type]
|
||||
engine_input = self._process_enc_dec(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
|
||||
else:
|
||||
engine_input = self._process_singleton(prompt)
|
||||
engine_input = self._process_singleton(prompt, skip_mm_cache=skip_mm_cache)
|
||||
|
||||
engine_input["arrival_time"] = arrival_time
|
||||
|
||||
return engine_input
|
||||
|
||||
async def process_for_engine_async(
|
||||
self, prompt: TokPrompt, arrival_time: float
|
||||
self,
|
||||
prompt: TokPrompt,
|
||||
arrival_time: float,
|
||||
*,
|
||||
skip_mm_cache: bool = False,
|
||||
) -> EngineInput:
|
||||
engine_input: EngineInput
|
||||
if "encoder_prompt" in prompt:
|
||||
engine_input = await self._process_enc_dec_async(
|
||||
prompt # type: ignore[arg-type]
|
||||
prompt, # type: ignore[arg-type]
|
||||
skip_mm_cache=skip_mm_cache,
|
||||
)
|
||||
else:
|
||||
engine_input = await self._process_singleton_async(prompt)
|
||||
engine_input = await self._process_singleton_async(
|
||||
prompt, skip_mm_cache=skip_mm_cache
|
||||
)
|
||||
|
||||
engine_input["arrival_time"] = arrival_time
|
||||
|
||||
@@ -827,6 +882,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
skip_mm_cache: bool = False,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
@@ -838,7 +894,10 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
|
||||
return [
|
||||
self.process_for_engine(prompt, arrival_time, skip_mm_cache=skip_mm_cache)
|
||||
for prompt in tok_prompts
|
||||
]
|
||||
|
||||
async def render_cmpl_async(
|
||||
self,
|
||||
@@ -846,6 +905,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
skip_mm_cache: bool = False,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
@@ -858,7 +918,12 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
return await asyncio.gather(
|
||||
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
|
||||
*(
|
||||
self.process_for_engine_async(
|
||||
p, arrival_time, skip_mm_cache=skip_mm_cache
|
||||
)
|
||||
for p in tok_prompts
|
||||
)
|
||||
)
|
||||
|
||||
def render_chat(
|
||||
@@ -868,6 +933,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
skip_mm_cache: bool = False,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
@@ -890,7 +956,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
eng_prompts = [
|
||||
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
|
||||
self.process_for_engine(prompt, arrival_time, skip_mm_cache=skip_mm_cache)
|
||||
for prompt in tok_prompts
|
||||
]
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
@@ -902,6 +969,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
skip_mm_cache: bool = False,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
@@ -924,7 +992,12 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
eng_prompts = await asyncio.gather(
|
||||
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
|
||||
*(
|
||||
self.process_for_engine_async(
|
||||
p, arrival_time, skip_mm_cache=skip_mm_cache
|
||||
)
|
||||
for p in tok_prompts
|
||||
)
|
||||
)
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
|
||||
Reference in New Issue
Block a user