[Bugfix] Offload blocking tokenizer ops to shared thread pool to unblock event loop (#34789)
Signed-off-by: Bvicii <yizhanhuang2002@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -55,6 +55,7 @@ class MockModelConfig:
|
||||
skip_tokenizer_init = False
|
||||
is_encoder_decoder: bool = False
|
||||
is_multimodal_model: bool = False
|
||||
renderer_num_workers: int = 1
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -536,6 +536,7 @@ class MockModelConfig:
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
is_multimodal_model: bool = False
|
||||
renderer_num_workers: int = 1
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -54,6 +54,7 @@ class MockModelConfig:
|
||||
skip_tokenizer_init = False
|
||||
is_encoder_decoder: bool = False
|
||||
is_multimodal_model: bool = False
|
||||
renderer_num_workers: int = 1
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -54,6 +54,7 @@ class MockModelConfig:
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
is_multimodal_model: bool = False
|
||||
renderer_num_workers: int = 1
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -38,6 +38,7 @@ class MockModelConfig:
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
is_multimodal_model: bool = False
|
||||
renderer_num_workers: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -37,6 +37,7 @@ class MockModelConfig:
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
is_multimodal_model: bool = False
|
||||
renderer_num_workers: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -291,6 +291,10 @@ class ModelConfig:
|
||||
definitions"""
|
||||
io_processor_plugin: str | None = None
|
||||
"""IOProcessor plugin name to load at model startup"""
|
||||
renderer_num_workers: int = 1
|
||||
"""Number of worker threads in the renderer thread pool. This pool
|
||||
handles async tokenization, chat template rendering, and multimodal
|
||||
preprocessing."""
|
||||
|
||||
# Pooler config
|
||||
pooler_config: PoolerConfig | None = None
|
||||
|
||||
@@ -508,6 +508,7 @@ class EngineArgs:
|
||||
MultiModalConfig.mm_encoder_attn_backend
|
||||
)
|
||||
io_processor_plugin: str | None = None
|
||||
renderer_num_workers: int = 1
|
||||
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
|
||||
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
|
||||
@@ -767,6 +768,10 @@ class EngineArgs:
|
||||
model_group.add_argument(
|
||||
"--io-processor-plugin", **model_kwargs["io_processor_plugin"]
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--renderer-num-workers",
|
||||
**model_kwargs["renderer_num_workers"],
|
||||
)
|
||||
|
||||
# Model loading arguments
|
||||
load_kwargs = get_kwargs(LoadConfig)
|
||||
@@ -1438,6 +1443,7 @@ class EngineArgs:
|
||||
video_pruning_rate=self.video_pruning_rate,
|
||||
mm_tensor_ipc=self.mm_tensor_ipc,
|
||||
io_processor_plugin=self.io_processor_plugin,
|
||||
renderer_num_workers=self.renderer_num_workers,
|
||||
)
|
||||
|
||||
def validate_tensorizer_args(self):
|
||||
|
||||
@@ -5,6 +5,7 @@ import copy
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Generic, overload
|
||||
|
||||
@@ -38,7 +39,10 @@ from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
|
||||
from vllm.multimodal.registry import MultiModalTimingRegistry
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
from vllm.utils.async_utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
make_async,
|
||||
)
|
||||
from vllm.utils.counter import AtomicCounter
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
@@ -78,11 +82,28 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Shared thread pool executor for blocking tokenizer and
|
||||
# multimodal preprocessing operations. The multimodal processor
|
||||
# receives a deep-copied tokenizer (see #36557) so it is safe to
|
||||
# run tokenization and MM preprocessing concurrently.
|
||||
pool_workers = config.model_config.renderer_num_workers
|
||||
self._executor = ThreadPoolExecutor(max_workers=pool_workers)
|
||||
|
||||
# Multimodal preprocessing is always offloaded to the thread pool
|
||||
# to keep the asyncio event loop responsive under concurrent load.
|
||||
self._mm_executor: Executor = self._executor
|
||||
|
||||
# Lazy initialization since offline LLM doesn't use async
|
||||
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
|
||||
|
||||
self.mm_processor: BaseMultiModalProcessor | None = None
|
||||
self._mm_cache_stats: MultiModalCacheStats | None = None
|
||||
self._clear_mm_cache_async = make_async(
|
||||
self.clear_mm_cache, executor=self._executor
|
||||
)
|
||||
self._process_multimodal_async = make_async(
|
||||
self._process_multimodal, executor=self._mm_executor
|
||||
)
|
||||
if config.model_config.is_multimodal_model:
|
||||
mm_processor_cache = mm_registry.processor_cache_from_config(config)
|
||||
|
||||
@@ -119,7 +140,9 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
    """Return the lazily created async tokenizer wrapper.

    The wrapper is built once, on first use, around ``get_tokenizer()`` and
    is handed this renderer's ``_executor`` so its blocking calls run on
    that pool.
    """
    if self._async_tokenizer is None:
        # Build exactly one wrapper. Constructing it without ``executor=``
        # makes AsyncMicrobatchTokenizer create its own single-thread pool,
        # which would be orphaned as soon as the attribute is reassigned.
        self._async_tokenizer = AsyncMicrobatchTokenizer(
            self.get_tokenizer(), executor=self._executor
        )

    return self._async_tokenizer
|
||||
|
||||
@@ -211,11 +234,24 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
finally:
|
||||
self.clear_mm_cache()
|
||||
|
||||
async def clear_mm_cache_async(self) -> None:
|
||||
"""Await the executor-backed wrapper around ``clear_mm_cache``.

Async callers use this instead of invoking ``clear_mm_cache`` directly.

NOTE(review): the wrapper is bound to the shared executor, so the clear
is serialized with other pool work only while that pool has a single
worker (``renderer_num_workers=1``, the default). Confirm the
mm_processor_cache is safe before raising the worker count."""
|
||||
await self._clear_mm_cache_async()
|
||||
|
||||
def shutdown(self) -> None:
    """Close the multimodal processor cache and stop the worker pool(s).

    The executors are shut down in a ``finally`` clause so their threads
    are released even when closing the cache raises; the exception still
    propagates to the caller. ``wait=False`` is used so shutdown does not
    block on work that is still running.
    """
    try:
        mm_processor_cache = self.mm_processor_cache
        if mm_processor_cache is not None:
            mm_processor_cache.close()
    finally:
        # getattr with a default tolerates instances on which the
        # executor attributes were never assigned.
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)

        # Only shut the multimodal executor down separately when it is a
        # distinct object from the shared one.
        mm_executor = getattr(self, "_mm_executor", None)
        if mm_executor is not None and mm_executor is not executor:
            mm_executor.shutdown(wait=False)
|
||||
|
||||
def get_bos_token_id(self) -> int | None:
|
||||
if self.tokenizer is None:
|
||||
logger.warning_once(
|
||||
@@ -621,6 +657,9 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
self,
|
||||
prompt: TokensPrompt,
|
||||
) -> TokensInput | MultiModalInput:
|
||||
"""Process token inputs, with multimodal preprocessing offloaded
|
||||
to the shared thread pool in the async variant.
|
||||
"""
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
|
||||
engine_input: TokensInput | MultiModalInput
|
||||
@@ -670,12 +709,46 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
cache_salt=prompt.get("cache_salt"),
|
||||
)
|
||||
|
||||
async def _process_tokens_async(
|
||||
self,
|
||||
prompt: TokensPrompt,
|
||||
) -> TokensInput | MultiModalInput:
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
|
||||
engine_input: TokensInput | MultiModalInput
|
||||
if multi_modal_data := prompt.get("multi_modal_data"):
|
||||
engine_input = await self._process_multimodal_async(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=None,
|
||||
mm_uuids=prompt.get("multi_modal_uuids"),
|
||||
)
|
||||
else:
|
||||
engine_input = tokens_input(prompt_token_ids)
|
||||
|
||||
if prompt_text := prompt.get("prompt"):
|
||||
engine_input["prompt"] = prompt_text
|
||||
if cache_salt := prompt.get("cache_salt"):
|
||||
engine_input["cache_salt"] = cache_salt
|
||||
|
||||
return engine_input
|
||||
|
||||
def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return self._process_tokens(prompt) # type: ignore[arg-type]
|
||||
|
||||
async def _process_singleton_async(
|
||||
self,
|
||||
prompt: SingletonTokPrompt,
|
||||
) -> SingletonInput:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return await self._process_tokens_async(prompt) # type: ignore[arg-type]
|
||||
|
||||
def _process_enc_dec(
|
||||
self,
|
||||
prompt: EncoderDecoderTokPrompt,
|
||||
@@ -699,6 +772,28 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
skip_decoder_start_token=skip_decoder_start_token,
|
||||
)
|
||||
|
||||
async def _process_enc_dec_async(
    self,
    prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInput:
    """Process the encoder and decoder halves of a prompt concurrently.

    Both halves are awaited together with ``asyncio.gather`` and the
    results are combined by ``build_enc_dec_input`` using this renderer's
    decoder start token id.
    """
    encoder_prompt = prompt["encoder_prompt"]
    decoder_prompt = prompt["decoder_prompt"]

    encoder_job = self._process_singleton_async(encoder_prompt)
    if decoder_prompt is None:
        # No decoder prompt: asyncio.sleep(0) is an awaitable that resolves
        # to None, which fills the decoder slot of the gather result.
        decoder_job = asyncio.sleep(0)
    else:
        decoder_job = self._process_singleton_async(decoder_prompt)

    processed = await asyncio.gather(encoder_job, decoder_job)
    encoder_input, decoder_input = processed

    return build_enc_dec_input(
        encoder_input=encoder_input,
        decoder_input=decoder_input,
        decoder_start_token_id=self.get_dec_start_token_id(),
    )
|
||||
|
||||
def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
|
||||
engine_input: EngineInput
|
||||
if "encoder_prompt" in prompt:
|
||||
@@ -710,6 +805,21 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
return engine_input
|
||||
|
||||
async def process_for_engine_async(
    self, prompt: TokPrompt, arrival_time: float
) -> EngineInput:
    """Turn a tokenized prompt into an engine input asynchronously.

    Prompts containing ``"encoder_prompt"`` take the encoder-decoder path;
    all others take the singleton path. The given ``arrival_time`` is
    stamped onto the result before it is returned.
    """
    if "encoder_prompt" in prompt:
        handler = self._process_enc_dec_async
    else:
        handler = self._process_singleton_async

    engine_input: EngineInput = await handler(prompt)  # type: ignore[arg-type]
    engine_input["arrival_time"] = arrival_time

    return engine_input
|
||||
|
||||
# Top-level methods
|
||||
def render_cmpl(
|
||||
self,
|
||||
@@ -747,7 +857,9 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
|
||||
return await asyncio.gather(
|
||||
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
|
||||
)
|
||||
|
||||
def render_chat(
|
||||
self,
|
||||
@@ -811,8 +923,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
eng_prompts = [
|
||||
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
|
||||
]
|
||||
eng_prompts = await asyncio.gather(
|
||||
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
|
||||
)
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
@@ -9,6 +10,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
from .base import BaseRenderer
|
||||
from .inputs import DictPrompt
|
||||
@@ -19,12 +21,25 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
def __init__(
    self,
    config: VllmConfig,
    tokenizer: DeepseekV32Tokenizer | None,
) -> None:
    """Initialize the base renderer and the async chat-template wrapper."""
    super().__init__(config, tokenizer)

    # Must follow super().__init__(): the wrapper is bound to
    # self._executor, which the base initializer assigns.
    template_fn = self._apply_chat_template
    self._apply_chat_template_async = make_async(
        template_fn,
        executor=self._executor,
    )
|
||||
|
||||
def _apply_chat_template(self, *args, **kwargs):
|
||||
return self.get_tokenizer().apply_chat_template(*args, **kwargs)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
self.model_config,
|
||||
@@ -33,7 +48,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
mm_processor_kwargs=params.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
prompt_raw = self._apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
@@ -52,7 +67,6 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
self.model_config,
|
||||
@@ -61,7 +75,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
mm_processor_kwargs=params.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
prompt_raw = await self._apply_chat_template_async(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
@@ -9,6 +10,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
from .base import BaseRenderer
|
||||
from .inputs import DictPrompt
|
||||
@@ -19,12 +21,25 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
def __init__(
    self,
    config: VllmConfig,
    tokenizer: Grok2Tokenizer | None,
) -> None:
    """Initialize the base renderer and the async chat-template wrapper."""
    super().__init__(config, tokenizer)

    # Must follow super().__init__(): the wrapper is bound to
    # self._executor, which the base initializer assigns.
    template_fn = self._apply_chat_template
    self._apply_chat_template_async = make_async(
        template_fn,
        executor=self._executor,
    )
|
||||
|
||||
def _apply_chat_template(self, *args, **kwargs):
|
||||
return self.get_tokenizer().apply_chat_template(*args, **kwargs)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
self.model_config,
|
||||
@@ -33,7 +48,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
mm_processor_kwargs=params.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
prompt_raw = self._apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
@@ -52,7 +67,6 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
self.model_config,
|
||||
@@ -61,7 +75,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
mm_processor_kwargs=params.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
prompt_raw = await self._apply_chat_template_async(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.logger import init_logger
|
||||
from vllm.tokenizers.hf import HfTokenizer
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils.async_utils import make_async
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .base import BaseRenderer
|
||||
@@ -614,6 +615,10 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
config.model_config.hf_config, "use_unified_vision_chunk", False
|
||||
)
|
||||
|
||||
self._apply_chat_template_async = make_async(
|
||||
safe_apply_chat_template, executor=self._executor
|
||||
)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
@@ -656,10 +661,13 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
video_placeholder = getattr(
|
||||
model_config.hf_config, "video_placeholder", None
|
||||
)
|
||||
prompt_raw = replace_vision_chunk_video_placeholder(
|
||||
prompt_raw,
|
||||
mm_data,
|
||||
video_placeholder,
|
||||
prompt_raw = cast(
|
||||
list[int],
|
||||
replace_vision_chunk_video_placeholder(
|
||||
prompt_raw,
|
||||
mm_data,
|
||||
video_placeholder,
|
||||
),
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
@@ -692,7 +700,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
mm_processor_kwargs=params.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
prompt_raw = await self._apply_chat_template_async(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
@@ -710,10 +718,13 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
video_placeholder = getattr(
|
||||
model_config.hf_config, "video_placeholder", None
|
||||
)
|
||||
prompt_raw = replace_vision_chunk_video_placeholder(
|
||||
prompt_raw,
|
||||
mm_data,
|
||||
video_placeholder,
|
||||
prompt_raw = cast(
|
||||
list[int],
|
||||
replace_vision_chunk_video_placeholder(
|
||||
prompt_raw,
|
||||
mm_data,
|
||||
video_placeholder,
|
||||
),
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
@@ -56,9 +55,8 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
|
||||
) -> None:
|
||||
super().__init__(config, tokenizer)
|
||||
|
||||
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._apply_chat_template_async = make_async(
|
||||
safe_apply_chat_template, executor=self._apply_chat_template_executor
|
||||
safe_apply_chat_template, executor=self._executor
|
||||
)
|
||||
|
||||
def render_messages(
|
||||
|
||||
@@ -34,6 +34,7 @@ class AsyncMicrobatchTokenizer:
|
||||
tokenizer,
|
||||
max_batch_size: int = 32,
|
||||
batch_wait_timeout_s: float = 0.002,
|
||||
executor: ThreadPoolExecutor | None = None,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.max_batch_size = max_batch_size
|
||||
@@ -47,7 +48,8 @@ class AsyncMicrobatchTokenizer:
|
||||
self._batcher_tasks: list[Task] = []
|
||||
|
||||
# Single-thread executor for blocking tokenizer calls.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
# Accept an external executor to serialize with other tokenizer users.
|
||||
self._executor = executor or ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# === Public async API ===
|
||||
async def __call__(self, prompt, **kwargs) -> BatchEncoding:
|
||||
|
||||
@@ -889,7 +889,7 @@ class AsyncLLM(EngineClient):
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def reset_mm_cache(self) -> None:
    """Clear the renderer's multimodal cache, then reset the engine core's.

    The renderer cache is cleared once, through the awaitable
    ``clear_mm_cache_async``. The synchronous ``clear_mm_cache`` is not
    called here, so the clear neither runs directly on the event loop nor
    happens a second time.
    """
    await self.renderer.clear_mm_cache_async()
    await self.engine_core.reset_mm_cache_async()
|
||||
|
||||
async def reset_prefix_cache(
|
||||
|
||||
Reference in New Issue
Block a user