[Bugfix] Offload blocking tokenizer ops to shared thread pool to unblock event loop (#34789)

Signed-off-by: Bvicii <yizhanhuang2002@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Bvicii
2026-03-26 22:17:00 -07:00
committed by GitHub
parent d86060122a
commit 999dfc1622
15 changed files with 195 additions and 28 deletions

View File

@@ -55,6 +55,7 @@ class MockModelConfig:
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False is_multimodal_model: bool = False
renderer_num_workers: int = 1
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}

View File

@@ -536,6 +536,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False is_multimodal_model: bool = False
renderer_num_workers: int = 1
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}

View File

@@ -54,6 +54,7 @@ class MockModelConfig:
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False is_multimodal_model: bool = False
renderer_num_workers: int = 1
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}

View File

@@ -54,6 +54,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False is_multimodal_model: bool = False
renderer_num_workers: int = 1
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}

View File

@@ -38,6 +38,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False is_multimodal_model: bool = False
renderer_num_workers: int = 1
@dataclass @dataclass

View File

@@ -37,6 +37,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False is_multimodal_model: bool = False
renderer_num_workers: int = 1
@dataclass @dataclass

View File

@@ -291,6 +291,10 @@ class ModelConfig:
definitions""" definitions"""
io_processor_plugin: str | None = None io_processor_plugin: str | None = None
"""IOProcessor plugin name to load at model startup""" """IOProcessor plugin name to load at model startup"""
renderer_num_workers: int = 1
"""Number of worker threads in the renderer thread pool. This pool
handles async tokenization, chat template rendering, and multimodal
preprocessing."""
# Pooler config # Pooler config
pooler_config: PoolerConfig | None = None pooler_config: PoolerConfig | None = None

View File

@@ -508,6 +508,7 @@ class EngineArgs:
MultiModalConfig.mm_encoder_attn_backend MultiModalConfig.mm_encoder_attn_backend
) )
io_processor_plugin: str | None = None io_processor_plugin: str | None = None
renderer_num_workers: int = 1
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
@@ -767,6 +768,10 @@ class EngineArgs:
model_group.add_argument( model_group.add_argument(
"--io-processor-plugin", **model_kwargs["io_processor_plugin"] "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
) )
model_group.add_argument(
"--renderer-num-workers",
**model_kwargs["renderer_num_workers"],
)
# Model loading arguments # Model loading arguments
load_kwargs = get_kwargs(LoadConfig) load_kwargs = get_kwargs(LoadConfig)
@@ -1438,6 +1443,7 @@ class EngineArgs:
video_pruning_rate=self.video_pruning_rate, video_pruning_rate=self.video_pruning_rate,
mm_tensor_ipc=self.mm_tensor_ipc, mm_tensor_ipc=self.mm_tensor_ipc,
io_processor_plugin=self.io_processor_plugin, io_processor_plugin=self.io_processor_plugin,
renderer_num_workers=self.renderer_num_workers,
) )
def validate_tensorizer_args(self): def validate_tensorizer_args(self):

View File

@@ -5,6 +5,7 @@ import copy
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload from typing import TYPE_CHECKING, Any, Generic, overload
@@ -38,7 +39,10 @@ from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
from vllm.multimodal.registry import MultiModalTimingRegistry from vllm.multimodal.registry import MultiModalTimingRegistry
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
make_async,
)
from vllm.utils.counter import AtomicCounter from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -78,11 +82,28 @@ class BaseRenderer(ABC, Generic[_T]):
self.tokenizer = tokenizer self.tokenizer = tokenizer
# Shared thread pool executor for blocking tokenizer and
# multimodal preprocessing operations. The multimodal processor
# receives a deep-copied tokenizer (see #36557) so it is safe to
# run tokenization and MM preprocessing concurrently.
pool_workers = config.model_config.renderer_num_workers
self._executor = ThreadPoolExecutor(max_workers=pool_workers)
# Multimodal preprocessing is always offloaded to the thread pool
# to keep the asyncio event loop responsive under concurrent load.
self._mm_executor: Executor = self._executor
# Lazy initialization since offline LLM doesn't use async # Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
self.mm_processor: BaseMultiModalProcessor | None = None self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None self._mm_cache_stats: MultiModalCacheStats | None = None
self._clear_mm_cache_async = make_async(
self.clear_mm_cache, executor=self._executor
)
self._process_multimodal_async = make_async(
self._process_multimodal, executor=self._mm_executor
)
if config.model_config.is_multimodal_model: if config.model_config.is_multimodal_model:
mm_processor_cache = mm_registry.processor_cache_from_config(config) mm_processor_cache = mm_registry.processor_cache_from_config(config)
@@ -119,7 +140,9 @@ class BaseRenderer(ABC, Generic[_T]):
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
if self._async_tokenizer is None: if self._async_tokenizer is None:
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer()) self._async_tokenizer = AsyncMicrobatchTokenizer(
self.get_tokenizer(), executor=self._executor
)
return self._async_tokenizer return self._async_tokenizer
@@ -211,11 +234,24 @@ class BaseRenderer(ABC, Generic[_T]):
finally: finally:
self.clear_mm_cache() self.clear_mm_cache()
async def clear_mm_cache_async(self) -> None:
"""Serialize clear_mm_cache through the shared executor to avoid
races with concurrent process_inputs on the mm_processor_cache."""
await self._clear_mm_cache_async()
def shutdown(self) -> None: def shutdown(self) -> None:
mm_processor_cache = self.mm_processor_cache mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None: if mm_processor_cache is not None:
mm_processor_cache.close() mm_processor_cache.close()
if executor := getattr(self, "_executor", None):
executor.shutdown(wait=False)
if (
mm_executor := getattr(self, "_mm_executor", None)
) is not None and mm_executor is not executor:
mm_executor.shutdown(wait=False)
def get_bos_token_id(self) -> int | None: def get_bos_token_id(self) -> int | None:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning_once( logger.warning_once(
@@ -621,6 +657,9 @@ class BaseRenderer(ABC, Generic[_T]):
self, self,
prompt: TokensPrompt, prompt: TokensPrompt,
) -> TokensInput | MultiModalInput: ) -> TokensInput | MultiModalInput:
"""Process token inputs, with multimodal preprocessing offloaded
to the shared thread pool in the async variant.
"""
prompt_token_ids = prompt["prompt_token_ids"] prompt_token_ids = prompt["prompt_token_ids"]
engine_input: TokensInput | MultiModalInput engine_input: TokensInput | MultiModalInput
@@ -670,12 +709,46 @@ class BaseRenderer(ABC, Generic[_T]):
cache_salt=prompt.get("cache_salt"), cache_salt=prompt.get("cache_salt"),
) )
async def _process_tokens_async(
self,
prompt: TokensPrompt,
) -> TokensInput | MultiModalInput:
prompt_token_ids = prompt["prompt_token_ids"]
engine_input: TokensInput | MultiModalInput
if multi_modal_data := prompt.get("multi_modal_data"):
engine_input = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None,
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
engine_input = tokens_input(prompt_token_ids)
if prompt_text := prompt.get("prompt"):
engine_input["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
engine_input["cache_salt"] = cache_salt
return engine_input
def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput: def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
if "prompt_embeds" in prompt: if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type] return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type] return self._process_tokens(prompt) # type: ignore[arg-type]
async def _process_singleton_async(
self,
prompt: SingletonTokPrompt,
) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return await self._process_tokens_async(prompt) # type: ignore[arg-type]
def _process_enc_dec( def _process_enc_dec(
self, self,
prompt: EncoderDecoderTokPrompt, prompt: EncoderDecoderTokPrompt,
@@ -699,6 +772,28 @@ class BaseRenderer(ABC, Generic[_T]):
skip_decoder_start_token=skip_decoder_start_token, skip_decoder_start_token=skip_decoder_start_token,
) )
async def _process_enc_dec_async(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
encoder_input, decoder_input = await asyncio.gather(
self._process_singleton_async(enc_prompt),
(
asyncio.sleep(0)
if dec_prompt is None
else self._process_singleton_async(dec_prompt)
),
)
return build_enc_dec_input(
encoder_input=encoder_input,
decoder_input=decoder_input,
decoder_start_token_id=self.get_dec_start_token_id(),
)
def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput: def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
engine_input: EngineInput engine_input: EngineInput
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
@@ -710,6 +805,21 @@ class BaseRenderer(ABC, Generic[_T]):
return engine_input return engine_input
async def process_for_engine_async(
self, prompt: TokPrompt, arrival_time: float
) -> EngineInput:
engine_input: EngineInput
if "encoder_prompt" in prompt:
engine_input = await self._process_enc_dec_async(
prompt # type: ignore[arg-type]
)
else:
engine_input = await self._process_singleton_async(prompt)
engine_input["arrival_time"] = arrival_time
return engine_input
# Top-level methods # Top-level methods
def render_cmpl( def render_cmpl(
self, self,
@@ -747,7 +857,9 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts] return await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
)
def render_chat( def render_chat(
self, self,
@@ -811,8 +923,8 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [ eng_prompts = await asyncio.gather(
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts *(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
] )
return out_conversations, eng_prompts return out_conversations, eng_prompts

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ConversationMessage, ConversationMessage,
@@ -9,6 +10,7 @@ from vllm.entrypoints.chat_utils import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from vllm.utils.async_utils import make_async
from .base import BaseRenderer from .base import BaseRenderer
from .inputs import DictPrompt from .inputs import DictPrompt
@@ -19,12 +21,25 @@ logger = init_logger(__name__)
class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]): class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
def __init__(
self,
config: VllmConfig,
tokenizer: DeepseekV32Tokenizer | None,
) -> None:
super().__init__(config, tokenizer)
self._apply_chat_template_async = make_async(
self._apply_chat_template, executor=self._executor
)
def _apply_chat_template(self, *args, **kwargs):
return self.get_tokenizer().apply_chat_template(*args, **kwargs)
def render_messages( def render_messages(
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
messages, messages,
self.model_config, self.model_config,
@@ -33,7 +48,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
mm_processor_kwargs=params.mm_processor_kwargs, mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = self._apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
@@ -52,7 +67,6 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async( conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages, messages,
self.model_config, self.model_config,
@@ -61,7 +75,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
mm_processor_kwargs=params.mm_processor_kwargs, mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = await self._apply_chat_template_async(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ConversationMessage, ConversationMessage,
@@ -9,6 +10,7 @@ from vllm.entrypoints.chat_utils import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers.grok2 import Grok2Tokenizer from vllm.tokenizers.grok2 import Grok2Tokenizer
from vllm.utils.async_utils import make_async
from .base import BaseRenderer from .base import BaseRenderer
from .inputs import DictPrompt from .inputs import DictPrompt
@@ -19,12 +21,25 @@ logger = init_logger(__name__)
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]): class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
def __init__(
self,
config: VllmConfig,
tokenizer: Grok2Tokenizer | None,
) -> None:
super().__init__(config, tokenizer)
self._apply_chat_template_async = make_async(
self._apply_chat_template, executor=self._executor
)
def _apply_chat_template(self, *args, **kwargs):
return self.get_tokenizer().apply_chat_template(*args, **kwargs)
def render_messages( def render_messages(
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
messages, messages,
self.model_config, self.model_config,
@@ -33,7 +48,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
mm_processor_kwargs=params.mm_processor_kwargs, mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = self._apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
@@ -52,7 +67,6 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async( conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages, messages,
self.model_config, self.model_config,
@@ -61,7 +75,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
mm_processor_kwargs=params.mm_processor_kwargs, mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = await self._apply_chat_template_async(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),

View File

@@ -30,6 +30,7 @@ from vllm.logger import init_logger
from vllm.tokenizers.hf import HfTokenizer from vllm.tokenizers.hf import HfTokenizer
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.async_utils import make_async
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
from .base import BaseRenderer from .base import BaseRenderer
@@ -614,6 +615,10 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
config.model_config.hf_config, "use_unified_vision_chunk", False config.model_config.hf_config, "use_unified_vision_chunk", False
) )
self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._executor
)
def render_messages( def render_messages(
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
@@ -656,10 +661,13 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
video_placeholder = getattr( video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None model_config.hf_config, "video_placeholder", None
) )
prompt_raw = replace_vision_chunk_video_placeholder( prompt_raw = cast(
list[int],
replace_vision_chunk_video_placeholder(
prompt_raw, prompt_raw,
mm_data, mm_data,
video_placeholder, video_placeholder,
),
) )
prompt = parse_dec_only_prompt(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
@@ -692,7 +700,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
mm_processor_kwargs=params.mm_processor_kwargs, mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = await self._apply_chat_template_async(
model_config, model_config,
tokenizer, tokenizer,
conversation, conversation,
@@ -710,10 +718,13 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
video_placeholder = getattr( video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None model_config.hf_config, "video_placeholder", None
) )
prompt_raw = replace_vision_chunk_video_placeholder( prompt_raw = cast(
list[int],
replace_vision_chunk_video_placeholder(
prompt_raw, prompt_raw,
mm_data, mm_data,
video_placeholder, video_placeholder,
),
) )
prompt = parse_dec_only_prompt(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
@@ -56,9 +55,8 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
) -> None: ) -> None:
super().__init__(config, tokenizer) super().__init__(config, tokenizer)
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
self._apply_chat_template_async = make_async( self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._apply_chat_template_executor safe_apply_chat_template, executor=self._executor
) )
def render_messages( def render_messages(

View File

@@ -34,6 +34,7 @@ class AsyncMicrobatchTokenizer:
tokenizer, tokenizer,
max_batch_size: int = 32, max_batch_size: int = 32,
batch_wait_timeout_s: float = 0.002, batch_wait_timeout_s: float = 0.002,
executor: ThreadPoolExecutor | None = None,
) -> None: ) -> None:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
@@ -47,7 +48,8 @@ class AsyncMicrobatchTokenizer:
self._batcher_tasks: list[Task] = [] self._batcher_tasks: list[Task] = []
# Single-thread executor for blocking tokenizer calls. # Single-thread executor for blocking tokenizer calls.
self._executor = ThreadPoolExecutor(max_workers=1) # Accept an external executor to serialize with other tokenizer users.
self._executor = executor or ThreadPoolExecutor(max_workers=1)
# === Public async API === # === Public async API ===
async def __call__(self, prompt, **kwargs) -> BatchEncoding: async def __call__(self, prompt, **kwargs) -> BatchEncoding:

View File

@@ -889,7 +889,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros) await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None: async def reset_mm_cache(self) -> None:
self.renderer.clear_mm_cache() await self.renderer.clear_mm_cache_async()
await self.engine_core.reset_mm_cache_async() await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache( async def reset_prefix_cache(