[Misc] Implement TokenizerLike.convert_tokens_to_ids (#31796)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-06 20:08:22 +08:00
committed by GitHub
parent bf0f3a4638
commit 6444824873
3 changed files with 31 additions and 3 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import Any
from typing import Any, overload
from transformers import BatchEncoding
@@ -65,6 +65,7 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
drop_thinking = messages[-1]["role"] == "user"
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
prompt_str = encode_messages(messages, **encode_config) # type: ignore
if kwargs.get("tokenize", True):
@@ -161,6 +162,15 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
add_special_tokens=add_special_tokens,
)
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
return self.tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    """Reassemble a list of tokens into a text string.

    Pure delegation to the wrapped HF tokenizer's detokenization logic.
    """
    text = self.tokenizer.convert_tokens_to_string(tokens)
    return text

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, cast, overload
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest,
@@ -441,6 +441,15 @@ class MistralTokenizer(TokenizerLike):
ids, skip_special_tokens=skip_special_tokens
)
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken:

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol, overload
if TYPE_CHECKING:
from transformers import BatchEncoding
@@ -100,6 +100,15 @@ class TokenizerLike(Protocol):
) -> str | list[int]:
raise NotImplementedError
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
raise NotImplementedError
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    """Protocol stub: join a list of tokens back into text.

    Concrete tokenizer implementations must override this; the protocol
    itself provides no behavior.
    """
    raise NotImplementedError