[Misc] Refactor tokenizer interface (#29693)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-11-29 20:02:21 +08:00
committed by GitHub
parent f223ed4181
commit 34a984274e
119 changed files with 752 additions and 821 deletions

View File

@@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import PlaceholderModule
try:
@@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str:
# Global cache for LoRA tokenizers.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
lora_tokenizer_cache: dict[int, TokenizerLike] = {}
def process_image(image: Any) -> Mapping[str, Any]:

View File

@@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
@@ -85,7 +85,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def get_tokenizer(self) -> AnyTokenizer:
async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer"""
...

View File

@@ -49,9 +49,9 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.func_utils import supports_kw
@@ -536,7 +536,7 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
@@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
@@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
super().__init__()
self._model_config = model_config
@@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
@@ -1624,7 +1624,7 @@ def parse_chat_messages(
def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],

View File

@@ -71,11 +71,8 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
MistralTokenizer,
get_cached_tokenizer,
)
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
@@ -350,11 +347,11 @@ class LLM:
self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
@@ -1244,7 +1241,7 @@ class LLM:
def _embedding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None,
@@ -1276,7 +1273,7 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None,

View File

@@ -62,8 +62,9 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
@@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
@@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time())
@@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing):
self,
logprobs: dict[int, Logprob],
top_logprobs: int | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]:
return [
@@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs:

View File

@@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin):
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):

View File

@@ -33,7 +33,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
@@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
@@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: list[CompletionResponseChoice] = []
@@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
initial_text_offset: int = 0,
return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
@@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id:
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)

View File

@@ -7,13 +7,14 @@ import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
@@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
@@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
@dataclass(kw_only=True)
class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Sequence[RequestPrompt] | None = []
engine_prompts: list[EngineTokensPrompt] | None = []
model_config = ConfigDict(arbitrary_types_allowed=True)
request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
class ResponseGenerationMixin(BaseModel):
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
@@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel):
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field(
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(
RequestProcessingMixin,
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
# Shared across all requests
request: RequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
# Shared across most requests
tokenizer: AnyTokenizer | None = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
tokenizer: TokenizerLike | None = None
ClassificationServeContext = ServeContext[ClassificationRequest]
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every requests ID (e.g. "embd", "classify")
@@ -281,7 +266,7 @@ class OpenAIServing:
apply_mistral_chat_template, executor=self._tokenizer_executor
)
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor
@@ -291,7 +276,7 @@ class OpenAIServing:
def _get_tool_parser(
self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
) -> Callable[[AnyTokenizer], ToolParser] | None:
) -> Callable[[TokenizerLike], ToolParser] | None:
"""Get the tool parser based on the name."""
parser = None
if not enable_auto_tools or tool_parser_name is None:
@@ -317,7 +302,7 @@ class OpenAIServing:
def _get_reasoning_parser(
self,
reasoning_parser_name: str,
) -> Callable[[AnyTokenizer], ReasoningParser] | None:
) -> Callable[[TokenizerLike], ReasoningParser] | None:
"""Get the reasoning parser based on the name."""
parser = None
if not reasoning_parser_name:
@@ -547,7 +532,7 @@ class OpenAIServing:
prompt_logprobs=None,
)
def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer:
def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
@@ -877,7 +862,7 @@ class OpenAIServing:
self,
request: AnyRequest,
prompt: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
@@ -919,7 +904,7 @@ class OpenAIServing:
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
@@ -1015,7 +1000,7 @@ class OpenAIServing:
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TextTokensPrompt:
@@ -1034,7 +1019,7 @@ class OpenAIServing:
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
@@ -1079,7 +1064,7 @@ class OpenAIServing:
async def _preprocess_chat(
self,
request: ChatLikeRequest | ResponsesRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
@@ -1088,13 +1073,18 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[
list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
@@ -1370,9 +1360,9 @@ class OpenAIServing:
@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
enable_auto_tools: bool,
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
function_calls = list[FunctionCall]()
@@ -1442,7 +1432,7 @@ class OpenAIServing:
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
return_as_token_id: bool = False,
) -> str:
if return_as_token_id:
@@ -1450,6 +1440,12 @@ class OpenAIServing:
if logprob.decoded_token is not None:
return logprob.decoded_token
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: str | None) -> bool:

View File

@@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
@@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
):
if request.tools is None or (
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
@@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int | None = None,
) -> ErrorResponse | ResponsesResponse:
@@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
logprobs: dict[int, SampleLogprob],
top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> list[LogprobTopLogprob]:
"""Returns the top-k logprobs from the logprobs dictionary."""
out = []
@@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
token_ids: Sequence[int],
logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
top_logprobs: int | None = None,
) -> list[Logprob]:
assert logprobs is not None, "logprobs must be provided"
@@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
token_ids: Sequence[int],
logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
top_logprobs: int | None = None,
) -> list[response_text_delta_event.Logprob]:
lgs = self._create_response_logprobs(
@@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
final_output: CompletionOutput,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> list[ResponseOutputItem]:
if self.reasoning_parser:
try:
@@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
@@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
@@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]:

View File

@@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__)
@@ -60,7 +60,7 @@ class ServingScores(OpenAIServing):
async def _embedding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
texts_1: list[str],
texts_2: list[str],
request: RerankRequest | ScoreRequest,
@@ -153,7 +153,7 @@ class ServingScores(OpenAIServing):
def _preprocess_score(
self,
request: RerankRequest | ScoreRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
@@ -175,7 +175,7 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest,

View File

@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing):
@dataclass
class TokenizerInfo:
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
chat_template: str | None
def to_dict(self) -> dict[str, Any]:

View File

@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.sampling_params import (
StructuredOutputsParams,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path
@@ -36,7 +36,7 @@ class ToolParser:
derived classes.
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1

View File

@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False

View File

@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False

View File

@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
"""
Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n

View File

@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []

View File

@@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.bot_token = "<function_call>"

View File

@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"

View File

@@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
if isinstance(tokenizer, MistralTokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.model_tokenizer = tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []

View File

@@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import consume_space
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode

View File

@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0

View File

@@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):

View File

@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []

View File

@@ -4,11 +4,11 @@
import regex as re
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.tool_call_start_token: str = "<longcat_tool_call>"

View File

@@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []

View File

@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize streaming state for tracking tool call progress

View File

@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
@@ -46,7 +46,7 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
@@ -61,7 +61,7 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):

View File

@@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
AnyTokenizer = object
TokenizerLike = object
logger = init_logger(__name__)
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "AnyTokenizer"):
def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer)
def extract_tool_calls(

View File

@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False

View File

@@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()

View File

@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# --- streaming state ---

View File

@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
@@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<tool_sep>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
# Explicit state flags for robust streaming

View File

@@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode

View File

@@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@@ -85,7 +85,7 @@ class BaseRenderer(ABC):
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
):
super().__init__()
self.model_config = model_config
@@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer):
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None,
async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer]
tokenizer: TokenizerLike | None = None,
async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
| None = None,
):
super().__init__(model_config, tokenizer)
@@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer):
return async_tokenizer
tokenizer = self.tokenizer
if self.tokenizer is None:
if tokenizer is None:
raise ValueError("No tokenizer available for text input processing")
if self.async_tokenizer_pool is None:

View File

@@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from vllm.transformers_utils.tokenizer import TokenizerLike
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
@@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False):
def _cosine_similarity(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
tokenizer: TokenizerLike,
embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
@@ -93,7 +89,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@@ -118,12 +114,14 @@ def _parse_score_content(
mm_tracker: BaseMultiModalItemTracker,
) -> _ContentPart | None:
if isinstance(data, str):
data = ChatCompletionContentPartTextParam(type="text", text=data)
part = ChatCompletionContentPartTextParam(type="text", text=data)
else:
part = data
mm_parser = mm_tracker.create_parser()
parse_res = _parse_chat_message_content_part(
data,
part,
mm_parser,
wrap_dicts=False,
interleave_strings=False,
@@ -181,7 +179,7 @@ def post_process_tokens(
def get_score_prompt(
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,

View File

@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)

View File

@@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -46,7 +46,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
@@ -59,7 +59,7 @@ class InputPreprocessor:
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True"
@@ -228,11 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer()
return tokenizer

View File

@@ -5,7 +5,7 @@ from typing import TypeAlias
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
LogitsProcessor: TypeAlias = (
Callable[[list[int], torch.Tensor], torch.Tensor]
@@ -19,7 +19,7 @@ to sample from."""
def get_bad_words_logits_processors(
bad_words: list[str], tokenizer: AnyTokenizer
bad_words: list[str], tokenizer: TokenizerLike
) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list()

View File

@@ -28,7 +28,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .intern_vit import InternVisionModel
from .internvl import (
@@ -241,7 +241,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,

View File

@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads
@@ -347,7 +347,7 @@ class BaseInternVLProcessor(ABC):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
@@ -561,7 +561,7 @@ class InternVLProcessor(BaseInternVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,

View File

@@ -73,9 +73,9 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
cached_tokenizer_from_config,
encode_tokens,
)
@@ -284,7 +284,7 @@ class BaseNanoNemotronVLProcessor(ABC):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*args,
max_num_tiles: int | None = None,
**kwargs,
@@ -434,7 +434,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None,
@@ -645,7 +645,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame: list[int],
frames_indices: list[int],
frame_duration_ms: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
img_start_token_ids: list[int],
img_end_token_ids: list[int],
img_context_token_ids: list[int],
@@ -670,7 +670,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame (list[int]): number of tokens per frame
frames_indices (list[int]): frame indices
frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators
tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens

View File

@@ -34,8 +34,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_image_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (
MultiModalEmbeddings,
@@ -203,7 +203,7 @@ class NemotronVLProcessor(InternVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
image_processor: BaseImageProcessorFast,
*,
min_dynamic_patch: int | None = None,

View File

@@ -31,7 +31,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .qwen2_5_vl import (
Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
@@ -79,7 +79,7 @@ class OpenCUAProcessor(Qwen2VLProcessor):
def __init__(
self,
vision_config: dict,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
**kwargs,
):
image_processor = Qwen2VLImageProcessor(**vision_config)

View File

@@ -59,10 +59,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
)
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

View File

@@ -91,7 +91,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
@@ -1533,7 +1533,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
def __init__(
self,
vision_config: dict,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
**kwargs,
):
self.image_processor = Tarsier2ImageProcessor(**vision_config)

View File

@@ -47,7 +47,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -282,7 +282,7 @@ class SkyworkR1VProcessor:
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,

View File

@@ -43,8 +43,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -321,7 +321,7 @@ class Step3VLProcessor:
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> None:
super().__init__()

View File

@@ -51,10 +51,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
)
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix

View File

@@ -23,8 +23,9 @@ import torch
from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
@@ -76,7 +77,7 @@ PromptSeq: TypeAlias = str | list[int]
@lru_cache(maxsize=2048)
def _cached_encode(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text: str,
*,
add_special_tokens: bool | None = None,
@@ -86,7 +87,7 @@ def _cached_encode(
@lru_cache(maxsize=2048)
def _cached_decode(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_ids: tuple[int, ...],
*,
skip_special_tokens: bool | None = None,
@@ -96,14 +97,14 @@ def _cached_decode(
)
def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
if isinstance(seq, str):
return seq
return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
if isinstance(seq, str):
return _cached_encode(tokenizer, seq, add_special_tokens=False)
@@ -113,7 +114,7 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
class _GetMatchIndex(Protocol):
def __call__(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt: PromptSeq,
start_idx: int = 0,
) -> int | None: ...
@@ -143,7 +144,7 @@ class PromptIndexTargets:
"""
def get_match_index(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt: PromptSeq,
start_idx: int = 0,
) -> int | None:
@@ -199,7 +200,7 @@ class PromptUpdateDetails(Generic[_S]):
full: _S
"""The full content."""
is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None
is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
"""
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions
@@ -220,7 +221,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_text: str,
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text)
token_ids = _seq2tokens(tokenizer, full)
@@ -236,7 +237,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_token_id: int,
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id
@@ -522,7 +523,7 @@ class ResolvedPromptUpdate:
def iter_token_matches(
self,
prompt: list[int],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@@ -544,7 +545,7 @@ class ResolvedPromptUpdate:
def iter_text_matches(
self,
prompt: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@@ -566,7 +567,7 @@ class ResolvedPromptUpdate:
def iter_matches(
self,
prompt: list[int] | str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@@ -675,7 +676,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
def _find_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult",
@@ -740,7 +741,7 @@ def _all_items_found(
def _apply_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
@@ -806,7 +807,7 @@ def _apply_matches(
def apply_token_matches(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
"""
Apply the updates in `mm_prompt_updates` to `prompt`.
@@ -823,7 +824,7 @@ def apply_token_matches(
def apply_text_matches(
prompt: str,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
"""
Apply the updates in `mm_prompt_updates` to `prompt`.
@@ -840,7 +841,7 @@ def apply_text_matches(
def _iter_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> Iterable[PlaceholderFeaturesInfo]:
"""
Yield each set of placeholder tokens found in `prompt`.
@@ -909,7 +910,7 @@ def _iter_placeholders(
def find_mm_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it))
@@ -930,7 +931,7 @@ class InputProcessingContext:
model_config: ModelConfig
"""The configuration of the model."""
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
"""The tokenizer used to tokenize the inputs."""
@overload
@@ -1146,7 +1147,7 @@ class BaseProcessingInfo:
def model_id(self) -> str:
return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
return self.ctx.tokenizer
def get_hf_config(self) -> PretrainedConfig:

View File

@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache
from .processing import (
@@ -231,17 +232,20 @@ class MultiModalRegistry:
def _create_processing_ctx(
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
if model_config.skip_tokenizer_init:
tokenizer = cast(TokenizerLike, object())
elif tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
*,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
@@ -252,7 +256,7 @@ class MultiModalRegistry:
self,
model_config: "ModelConfig",
*,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""

View File

@@ -19,12 +19,12 @@ if TYPE_CHECKING:
DeltaMessage,
ResponsesRequest,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
ChatCompletionRequest = Any
DeltaMessage = Any
ResponsesRequest = Any
AnyTokenizer = Any
TokenizerLike = Any
logger = init_logger(__name__)
@@ -37,7 +37,7 @@ class ReasoningParser:
It is used to extract reasoning content from the model output.
"""
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
self.model_tokenizer = tokenizer
@cached_property

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import (
@@ -43,7 +43,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
"""The token that ends reasoning content."""
raise NotImplementedError
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
if not self.model_tokenizer:

View File

@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@@ -37,7 +37,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
Reasoning parser for MiniMax M2 model.
"""
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>")

View File

@@ -6,7 +6,7 @@ from functools import cached_property
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
logger = init_logger(__name__)

View File

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
import regex as re
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -220,7 +220,7 @@ class Olmo3ReasoningParser(ReasoningParser):
token is missing from generation.
"""
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.think_start = r"<think>"

View File

@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__)
@@ -477,7 +477,7 @@ class SamplingParams(
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
if not self.bad_words:
return
self._bad_words_token_ids = []

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .mistral import MistralTokenizer
from .protocol import TokenizerLike
from .registry import TokenizerRegistry
__all__ = ["TokenizerLike", "MistralTokenizer", "TokenizerRegistry"]

View File

@@ -4,7 +4,8 @@
from typing import TYPE_CHECKING, Any, cast
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase
from .protocol import TokenizerLike
if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import (
@@ -163,7 +164,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
return tokenizer.unk_id
class MistralTokenizer(TokenizerBase):
class MistralTokenizer(TokenizerLike):
def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.sentencepiece import (
@@ -270,14 +271,6 @@ class MistralTokenizer(TokenizerBase):
def eos_token_id(self) -> int:
return self.tokenizer.eos_id
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
return self.transformers_tokenizer.pad_token
@property
def is_fast(self) -> bool:
return True
@@ -292,11 +285,14 @@ class MistralTokenizer(TokenizerBase):
@property
def truncation_side(self) -> str:
raise NotImplementedError()
return self.transformers_tokenizer.truncation_side
def _is_special_token_id(self, token_id: int) -> bool:
return token_id in self._special_token_ids_set
def __hash__(self) -> int:
return hash(id(self))
def __len__(self) -> int:
return self.vocab_size
@@ -341,17 +337,6 @@ class MistralTokenizer(TokenizerBase):
# Mistral tokenizers have no added vocabulary
return {}
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: int | None = None,
) -> list[int]:
# Mistral Tokenizers should not add special tokens
return self.transformers_tokenizer.encode(
text, add_special_tokens=False, truncation=truncation, max_length=max_length
)
def encode(
self,
text: str,

105
vllm/tokenizers/protocol.py Normal file
View File

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Protocol
from typing_extensions import Self
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
class TokenizerLike(Protocol):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
/,
*,
revision: str | None = None,
) -> Self:
raise NotImplementedError
@property
def all_special_tokens(self) -> list[str]:
raise NotImplementedError
@property
def all_special_ids(self) -> list[int]:
raise NotImplementedError
@property
def bos_token_id(self) -> int:
raise NotImplementedError
@property
def eos_token_id(self) -> int:
raise NotImplementedError
@property
def is_fast(self) -> bool:
raise NotImplementedError
@property
def vocab_size(self) -> int:
raise NotImplementedError
@property
def max_token_id(self) -> int:
raise NotImplementedError
@property
def truncation_side(self) -> str:
raise NotImplementedError
def __hash__(self) -> int:
return hash(id(self))
def __len__(self) -> int:
return self.vocab_size
def __call__(
self,
text: str | list[str] | list[int],
text_pair: str | None = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: int | None = None,
):
raise NotImplementedError
def get_vocab(self) -> dict[str, int]:
raise NotImplementedError
def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError
def encode(
self,
text: str,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool | None = None,
) -> list[int]:
raise NotImplementedError
def apply_chat_template(
self,
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> list[int]:
raise NotImplementedError
def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
raise NotImplementedError
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
raise NotImplementedError

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from .protocol import TokenizerLike
class TokenizerRegistry:
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY: dict[str, tuple[str, str]] = {}
@staticmethod
def register(name: str, module: str, class_name: str) -> None:
TokenizerRegistry.REGISTRY[name] = (module, class_name)
@staticmethod
def get_tokenizer(
tokenizer_name: str,
*args,
**kwargs,
) -> "TokenizerLike":
tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
if tokenizer_cls is None:
raise ValueError(f"Tokenizer {tokenizer_name} not found.")
tokenizer_module = importlib.import_module(tokenizer_cls[0])
class_ = getattr(tokenizer_module, tokenizer_cls[1])
return class_.from_pretrained(*args, **kwargs)

View File

@@ -26,8 +26,9 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.repo_utils import (
from .config_parser_base import ConfigParserBase
from .repo_utils import (
_get_hf_token,
file_or_path_exists,
get_hf_file_to_dict,
@@ -35,7 +36,7 @@ from vllm.transformers_utils.repo_utils import (
try_get_local_file,
with_retry,
)
from vllm.transformers_utils.utils import (
from .utils import (
check_gguf_file,
is_gguf,
is_remote_gguf,

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
def _replace_none_with_empty(tokens: list[str | None]):
@@ -12,7 +12,7 @@ def _replace_none_with_empty(tokens: list[str | None]):
def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
output_tokens: list[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
@@ -57,7 +57,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_ids: list[int],
skip_special_tokens: bool = False,
) -> tuple[list[str], int, int]:
@@ -81,7 +81,7 @@ def convert_prompt_ids_to_tokens(
def convert_ids_list_to_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_ids: list[int],
) -> list[str]:
"""Detokenize the input ids individually.
@@ -108,7 +108,7 @@ def convert_ids_list_to_tokens(
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
all_input_ids: list[int],
prev_tokens: list[str] | None,
prefix_offset: int,

View File

@@ -9,7 +9,8 @@ from gguf.constants import Keys, VisionProjectorType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
from .repo_utils import list_filtered_repo_files
logger = init_logger(__name__)

View File

@@ -5,41 +5,48 @@ import contextlib
import copy
import importlib.util
import os
import warnings
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any
import huggingface_hub
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing_extensions import assert_never
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import (
check_gguf_file,
is_gguf,
is_remote_gguf,
split_remote_gguf,
)
from vllm.tokenizers import MistralTokenizer, TokenizerLike, TokenizerRegistry
from .config import get_sentence_transformer_tokenizer_config
from .gguf_utils import get_gguf_file_path_from_hf
from .repo_utils import list_filtered_repo_files
from .utils import check_gguf_file, is_gguf, is_remote_gguf, split_remote_gguf
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
ModelConfig = Any
TokenizerBase = Any
logger = init_logger(__name__)
AnyTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast | TokenizerBase
def __getattr__(name: str):
if name == "AnyTokenizer":
warnings.warn(
"`vllm.transformers_utils.tokenizer.AnyTokenizer` has been moved to "
"`vllm.tokenizers.TokenizerLike`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return TokenizerLike
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def decode_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_ids: list[int],
*,
skip_special_tokens: bool | None = None,
@@ -58,7 +65,7 @@ def decode_tokens(
def encode_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text: str,
*,
truncation: bool | None = None,
@@ -86,7 +93,7 @@ def encode_tokens(
return tokenizer.encode(text, **kw_args)
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
def get_cached_tokenizer(tokenizer: TokenizerLike) -> TokenizerLike:
"""
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown.
@@ -144,7 +151,7 @@ def get_tokenizer(
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> AnyTokenizer:
) -> TokenizerLike:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
@@ -206,15 +213,13 @@ def get_tokenizer(
if len(files_list) > 0:
tokenizer_mode = "mistral"
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
if tokenizer_mode == "mistral":
logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
tokenizer = MistralTokenizer.from_pretrained(
str(tokenizer_name), revision=revision
)
elif tokenizer_mode == "custom":
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
tokenizer = TokenizerRegistry.get_tokenizer(
str(tokenizer_name),
@@ -260,12 +265,13 @@ def get_tokenizer(
if isinstance(encoder_config, dict) and encoder_config.get(
"do_lower_case", False
):
assert isinstance(tokenizer, PreTrainedTokenizerBase)
special_tokens_map = {
k: v.lower() for k, v in tokenizer.special_tokens_map.items()
}
tokenizer.add_special_tokens(special_tokens_map)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
if not tokenizer.is_fast:
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
@@ -279,7 +285,7 @@ cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(
model_config: ModelConfig,
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_tokenizer(
@@ -291,7 +297,7 @@ def cached_tokenizer_from_config(
)
def init_tokenizer_from_configs(model_config: ModelConfig):
def init_tokenizer_from_configs(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"

View File

@@ -1,150 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
import warnings
class TokenizerBase(ABC):
@property
@abstractmethod
def all_special_tokens(self) -> list[str]:
raise NotImplementedError()
def __getattr__(name: str):
if name == "TokenizerBase":
from vllm.tokenizers import TokenizerLike
@property
@abstractmethod
def all_special_ids(self) -> list[int]:
raise NotImplementedError()
warnings.warn(
"`vllm.transformers_utils.tokenizer_base.TokenizerBase` has been "
"moved to `vllm.tokenizers.TokenizerLike`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
@property
@abstractmethod
def bos_token_id(self) -> int:
raise NotImplementedError()
return TokenizerLike
if name == "TokenizerRegistry":
from vllm.tokenizers import TokenizerRegistry
@property
@abstractmethod
def eos_token_id(self) -> int:
raise NotImplementedError()
warnings.warn(
"`vllm.transformers_utils.tokenizer_base.TokenizerRegistry` has been "
"moved to `vllm.tokenizers.TokenizerRegistry`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
@property
@abstractmethod
def sep_token(self) -> str:
raise NotImplementedError()
return TokenizerRegistry
@property
@abstractmethod
def pad_token(self) -> str:
raise NotImplementedError()
@property
@abstractmethod
def is_fast(self) -> bool:
raise NotImplementedError()
@property
@abstractmethod
def vocab_size(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def max_token_id(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def truncation_side(self) -> str:
raise NotImplementedError()
def __len__(self) -> int:
return self.vocab_size
@abstractmethod
def __call__(
self,
text: str | list[str] | list[int],
text_pair: str | None = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: int | None = None,
):
raise NotImplementedError()
@abstractmethod
def get_vocab(self) -> dict[str, int]:
raise NotImplementedError()
@abstractmethod
def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError()
@abstractmethod
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: int | None = None,
) -> list[int]:
raise NotImplementedError()
@abstractmethod
def encode(
self,
text: str,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool | None = None,
) -> list[int]:
raise NotImplementedError()
@abstractmethod
def apply_chat_template(
self,
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> list[int]:
raise NotImplementedError()
@abstractmethod
def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError()
@abstractmethod
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
@abstractmethod
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
raise NotImplementedError()
class TokenizerRegistry:
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY: dict[str, tuple[str, str]] = {}
@staticmethod
def register(name: str, module: str, class_name: str) -> None:
TokenizerRegistry.REGISTRY[name] = (module, class_name)
@staticmethod
def get_tokenizer(
tokenizer_name: str,
*args,
**kwargs,
) -> TokenizerBase:
tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
if tokenizer_cls is None:
raise ValueError(f"Tokenizer {tokenizer_name} not found.")
tokenizer_module = importlib.import_module(tokenizer_cls[0])
class_ = getattr(tokenizer_module, tokenizer_cls[1])
return class_.from_pretrained(*args, **kwargs)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -1,16 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
)
__all__ = [
"MistralTokenizer",
"maybe_serialize_tool_calls",
"truncate_tool_call_ids",
"validate_request_params",
]

View File

@@ -26,9 +26,10 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
@@ -120,9 +121,10 @@ class AsyncLLM(EngineClient):
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
stream_interval = self.vllm_config.scheduler_config.stream_interval
self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
self.tokenizer,
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
@@ -703,17 +705,17 @@ class AsyncLLM(EngineClient):
raise EngineGenerateError() from e
@property
def tokenizer(self) -> AnyTokenizer | None:
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_processor.tokenizer = tokenizer
async def get_tokenizer(self) -> AnyTokenizer:
async def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because skip_tokenizer_init is True"
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.tokenizer

View File

@@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerFast
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer,
TokenizerLike,
convert_prompt_ids_to_tokens,
detokenize_incrementally,
)
@@ -45,7 +45,7 @@ class IncrementalDetokenizer:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
assert request.sampling_params is not None
@@ -256,7 +256,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
super().__init__(request)
self.tokenizer = tokenizer

View File

@@ -19,8 +19,7 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -40,7 +39,7 @@ class InputProcessor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
self.vllm_config = vllm_config
@@ -62,11 +61,11 @@ class InputProcessor:
)
@property
def tokenizer(self) -> AnyTokenizer | None:
def tokenizer(self) -> TokenizerLike | None:
return self.input_preprocessor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_preprocessor.tokenizer = tokenizer
def _validate_logprobs(

View File

@@ -23,8 +23,9 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
@@ -95,9 +96,10 @@ class LLMEngine:
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
stream_interval = self.vllm_config.scheduler_config.stream_interval
self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
self.tokenizer,
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
@@ -350,17 +352,17 @@ class LLMEngine:
return get_metrics_snapshot()
@property
def tokenizer(self) -> AnyTokenizer | None:
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_processor.tokenizer = tokenizer
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because skip_tokenizer_init is True"
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.tokenizer

View File

@@ -13,7 +13,7 @@ from vllm.logprobs import (
create_sample_logprobs,
)
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer,
TokenizerLike,
convert_ids_list_to_tokens,
)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
@@ -28,7 +28,7 @@ NONES = itertools.repeat(None)
class LogprobsProcessor:
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: AnyTokenizer | None
tokenizer: TokenizerLike | None
# Logprobs for this request
logprobs: SampleLogprobs | None
@@ -40,7 +40,7 @@ class LogprobsProcessor:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
request: EngineCoreRequest,
) -> "LogprobsProcessor":
sampling_params = request.sampling_params

View File

@@ -15,8 +15,8 @@ from vllm.outputs import (
RequestOutput,
)
from vllm.sampling_params import RequestOutputKind
from vllm.tokenizers import TokenizerLike
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
@@ -139,7 +139,7 @@ class RequestState:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request: EngineCoreRequest,
prompt: str | None,
parent_req: ParentRequest | None,
@@ -341,7 +341,10 @@ class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(
self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
self,
tokenizer: TokenizerLike | None,
log_stats: bool,
stream_interval: int = 1,
):
self.log_stats = log_stats
self.tokenizer = tokenizer

View File

@@ -10,10 +10,10 @@ if TYPE_CHECKING:
import torch
from vllm.config import VllmConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
VllmConfig = object
AnyTokenizer = object
TokenizerLike = object
class StructuredOutputOptions(enum.Enum):
@@ -100,7 +100,7 @@ class StructuredOutputBackend(ABC):
"""Engine-level backend for structured output requests."""
vllm_config: VllmConfig
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
vocab_size: int
@abstractmethod

View File

@@ -10,7 +10,7 @@ import torch
import vllm.envs
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.utils.import_utils import LazyLoader
from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend,

View File

@@ -24,7 +24,7 @@ if TYPE_CHECKING:
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.worker.gpu_input_batch import InputBatch
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
@@ -36,7 +36,7 @@ else:
"transformers.models.gpt2.tokenization_gpt2",
)
AnyTokenizer = object
TokenizerLike = object
SchedulerOutput = object
InputBatch = object
@@ -195,7 +195,7 @@ re_replacement_seq = re.compile(r"^.{0,6}<7D>+.{0,6}$")
def _reduced_vocabulary(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
eos_token_id: int,
) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.
@@ -222,7 +222,7 @@ def _reduced_vocabulary(
vocabulary: dict[bytes, list[int]] = {}
empty_token_ids: list[int] = []
for token, token_idx in tokenizer.get_vocab().items():
if token in tokenizer.all_special_tokens: # type: ignore
if token in tokenizer.all_special_tokens:
continue
token_str = convert_token_to_string(token)
@@ -261,7 +261,7 @@ def _reduced_vocabulary(
return vocabulary
def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary:
def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
"""Get the `Vocabulary` object for a given tokenizer."""
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary # type: ignore