From 6444824873711eeeb68ded97cad79394bd051ed1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 6 Jan 2026 20:08:22 +0800 Subject: [PATCH] [Misc] Implement `TokenizerLike.convert_tokens_to_ids` (#31796) Signed-off-by: DarkLight1337 --- vllm/tokenizers/deepseek_v32.py | 12 +++++++++++- vllm/tokenizers/mistral.py | 11 ++++++++++- vllm/tokenizers/protocol.py | 11 ++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/vllm/tokenizers/deepseek_v32.py b/vllm/tokenizers/deepseek_v32.py index d519b61dd..4402054c9 100644 --- a/vllm/tokenizers/deepseek_v32.py +++ b/vllm/tokenizers/deepseek_v32.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path -from typing import Any +from typing import Any, overload from transformers import BatchEncoding @@ -65,6 +65,7 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): drop_thinking = messages[-1]["role"] == "user" encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) + prompt_str = encode_messages(messages, **encode_config) # type: ignore if kwargs.get("tokenize", True): @@ -161,6 +162,15 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): add_special_tokens=add_special_tokens, ) + @overload + def convert_tokens_to_ids(self, tokens: str) -> int: ... + + @overload + def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ... + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + return self.tokenizer.convert_tokens_to_ids(tokens) + def convert_tokens_to_string(self, tokens: list[str]) -> str: return self.tokenizer.convert_tokens_to_string(tokens) diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 090286228..35a11e95b 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, overload from mistral_common.protocol.instruct.request import ( ChatCompletionRequest as MistralChatCompletionRequest, @@ -441,6 +441,15 @@ class MistralTokenizer(TokenizerLike): ids, skip_special_tokens=skip_special_tokens ) + @overload + def convert_tokens_to_ids(self, tokens: str) -> int: ... + + @overload + def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ... + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + return self.transformers_tokenizer.convert_tokens_to_ids(tokens) + def convert_tokens_to_string(self, tokens: list[str]) -> str: to_decode_special_tokens = {SpecialTokens.tool_calls} if self.is_tekken: diff --git a/vllm/tokenizers/protocol.py b/vllm/tokenizers/protocol.py index 28754f9e1..21e5b3a7b 100644 --- a/vllm/tokenizers/protocol.py +++ b/vllm/tokenizers/protocol.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, overload if TYPE_CHECKING: from transformers import BatchEncoding @@ -100,6 +100,15 @@ class TokenizerLike(Protocol): ) -> str | list[int]: raise NotImplementedError + @overload + def convert_tokens_to_ids(self, tokens: str) -> int: ... + + @overload + def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ... + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + raise NotImplementedError + def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError