[Misc] Implement TokenizerLike.convert_tokens_to_ids (#31796)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, overload
|
||||
|
||||
from transformers import BatchEncoding
|
||||
|
||||
@@ -65,6 +65,7 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
|
||||
drop_thinking = messages[-1]["role"] == "user"
|
||||
|
||||
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
|
||||
|
||||
prompt_str = encode_messages(messages, **encode_config) # type: ignore
|
||||
|
||||
if kwargs.get("tokenize", True):
|
||||
@@ -161,6 +162,15 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: str) -> int: ...
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
|
||||
|
||||
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
|
||||
return self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    """Reassemble a sequence of token strings into decoded text.

    Delegates to the wrapped tokenizer held in ``self.tokenizer``.
    """
    hf_tokenizer = self.tokenizer
    return hf_tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast, overload
|
||||
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest as MistralChatCompletionRequest,
|
||||
@@ -441,6 +441,15 @@ class MistralTokenizer(TokenizerLike):
|
||||
ids, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: str) -> int: ...
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
|
||||
|
||||
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
|
||||
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
to_decode_special_tokens = {SpecialTokens.tool_calls}
|
||||
if self.is_tekken:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import BatchEncoding
|
||||
@@ -100,6 +100,15 @@ class TokenizerLike(Protocol):
|
||||
) -> str | list[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: str) -> int: ...
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
|
||||
|
||||
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    """Reassemble a sequence of token strings into decoded text.

    Protocol stub: concrete tokenizer implementations must override this.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user