diff --git a/tests/tokenizers_/test_basic.py b/tests/tokenizers_/test_basic.py index b5c26a659..99f68ecd0 100644 --- a/tests/tokenizers_/test_basic.py +++ b/tests/tokenizers_/test_basic.py @@ -11,6 +11,7 @@ from transformers import ( from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.grok2 import Grok2Tokenizer +from vllm.tokenizers.hf import HfTokenizer from vllm.tokenizers.mistral import MistralTokenizer @@ -42,6 +43,13 @@ def test_tokenizer_like_protocol(): assert isinstance(tokenizer, Grok2Tokenizer) _assert_tokenizer_like(tokenizer) + tokenizer = get_tokenizer("deepseek-ai/DeepSeek-V3", tokenizer_mode="deepseek_v32") + assert isinstance(tokenizer, HfTokenizer) + # Verify it's a fast tokenizer (required for FastIncrementalDetokenizer) + assert isinstance(tokenizer, PreTrainedTokenizerFast) + assert "DSV32" in tokenizer.__class__.__name__ + _assert_tokenizer_like(tokenizer) + @pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"]) def test_tokenizer_revision(tokenizer_name: str): diff --git a/vllm/renderers/deepseek_v32.py b/vllm/renderers/deepseek_v32.py index f83edd16f..d10a596b2 100644 --- a/vllm/renderers/deepseek_v32.py +++ b/vllm/renderers/deepseek_v32.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from ..tokenizers.hf import HfTokenizer from .inputs import DictPrompt from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams @@ -48,10 +49,10 @@ class DeepseekV32Renderer(BaseRenderer): self._tokenizer = tokenizer @property - def tokenizer(self) -> DeepseekV32Tokenizer | None: + def tokenizer(self) -> HfTokenizer | None: return self._tokenizer - def get_tokenizer(self) -> DeepseekV32Tokenizer: + def get_tokenizer(self) -> HfTokenizer: tokenizer = self.tokenizer if tokenizer is None: raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") diff --git a/vllm/tokenizers/deepseek_v32.py b/vllm/tokenizers/deepseek_v32.py index cb0ffe73a..28071ef69 100644 --- a/vllm/tokenizers/deepseek_v32.py +++ b/vllm/tokenizers/deepseek_v32.py @@ -1,191 +1,89 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from typing import Any -from pathlib import Path -from typing import Any, overload - -from transformers import BatchEncoding +from transformers import AutoTokenizer from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from . import TokenizerLike from .deepseek_v32_encoding import encode_messages -from .hf import CachedHfTokenizer -from .protocol import TokenizerLike +from .hf import HfTokenizer, get_cached_tokenizer -class DeepseekV32Tokenizer(CachedHfTokenizer): - @classmethod - def from_pretrained( - cls, - path_or_repo_id: str | Path, - *args, - trust_remote_code: bool = False, - revision: str | None = None, - download_dir: str | None = None, - **kwargs, - ) -> "TokenizerLike": - tokenizer = super().from_pretrained( - path_or_repo_id, - *args, - trust_remote_code=trust_remote_code, - revision=revision, - download_dir=download_dir, +def get_deepseek_v32_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: + """ + Wraps a tokenizer to use the custom DeepSeek V3.2 chat template encoding. + """ + dsv32_tokenizer = copy.copy(tokenizer) + + added_vocab = tokenizer.get_added_vocab() + added_vocab_size = len(added_vocab) + tokenizer_vocab_size = tokenizer.vocab_size + + class _DeepseekV32Tokenizer(tokenizer.__class__): # type: ignore + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, **kwargs, - ) - return DeepseekV32Tokenizer(tokenizer) + ) -> str | list[int]: + thinking = kwargs.get("thinking", False) + enable_thinking = kwargs.get("enable_thinking", False) + thinking = thinking or enable_thinking + thinking_mode = "thinking" + if not thinking: + thinking_mode = "chat" + conversation = kwargs.get("conversation", messages) + messages = conversation.copy() + if tools is not None and len(tools) > 0: + messages.insert(0, {"role": "system"}) + messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] - def __init__(self, tokenizer: TokenizerLike) -> None: - super().__init__() + # Historical reasoning content is dropped when a new user message + # is introduced + drop_thinking = messages[-1]["role"] == "user" - self.tokenizer = tokenizer - self.name_or_path = getattr(tokenizer, "name_or_path", "") - - self._added_vocab = self.tokenizer.get_added_vocab() - self._added_vocab_size = len(self._added_vocab) - - def apply_chat_template( - self, - messages: list["ChatCompletionMessageParam"], - tools: list[dict[str, Any]] | None = None, - **kwargs, - ) -> str | list[int]: - thinking = kwargs.get("thinking", False) - enable_thinking = kwargs.get("enable_thinking", False) - thinking = thinking or enable_thinking - thinking_mode = "thinking" - if not thinking: - thinking_mode = "chat" - conversation = kwargs.get("conversation", messages) - messages = conversation.copy() - if tools is not None and len(tools) > 0: - messages.insert(0, {"role": "system"}) - messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] - - # Historical reasoning content is dropped when a new user message is introduced - drop_thinking = messages[-1]["role"] == "user" - - encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) - - prompt_str = encode_messages(messages, **encode_config) # type: ignore - - if kwargs.get("tokenize", True): - tokenizer_kwargs = { - k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs - } - return self.encode( - prompt_str, - add_special_tokens=False, - **tokenizer_kwargs, + encode_config = dict( + thinking_mode=thinking_mode, drop_thinking=drop_thinking ) - return prompt_str + prompt_str = encode_messages(messages, **encode_config) # type: ignore - def num_special_tokens_to_add(self) -> int: - return len(self.encode("")) + if kwargs.get("tokenize", True): + tokenizer_kwargs = { + k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs + } + return self.encode( + prompt_str, + add_special_tokens=False, + **tokenizer_kwargs, + ) - @property - def all_special_tokens(self) -> list[str]: - return self.tokenizer.all_special_tokens + return prompt_str - @property - def all_special_ids(self) -> list[int]: - return self.tokenizer.all_special_ids + def num_special_tokens_to_add(self) -> int: + return len(self.encode("")) - @property - def bos_token_id(self) -> int: - return self.tokenizer.bos_token_id + def __len__(self) -> int: + # is an added token in DeepseekV32 tokenizer + return tokenizer_vocab_size + added_vocab_size - @property - def eos_token_id(self) -> int: - return self.tokenizer.eos_token_id + def get_added_vocab(self) -> dict[str, int]: + return added_vocab.copy() - @property - def pad_token_id(self) -> int: - return self.tokenizer.pad_token_id + def __reduce__(self): + return get_deepseek_v32_tokenizer, (tokenizer,) - @property - def is_fast(self) -> bool: - return self.tokenizer.is_fast + _DeepseekV32Tokenizer.__name__ = f"DSV32{tokenizer.__class__.__name__}" - @property - def vocab_size(self) -> int: - return self.tokenizer.vocab_size + dsv32_tokenizer.__class__ = _DeepseekV32Tokenizer + return dsv32_tokenizer - @property - def max_token_id(self) -> int: - return self.tokenizer.max_token_id - @property - def max_chars_per_token(self) -> int: - return self.tokenizer.max_chars_per_token - - @property - def truncation_side(self) -> str: - return self.tokenizer.truncation_side - - def __hash__(self) -> int: - return hash(id(self)) - - def __len__(self) -> int: - # is an added token in DeepseekV32 tokenizer - return self.vocab_size + self._added_vocab_size - - def __call__( - self, - text: str | list[str], - text_pair: str | None = None, - add_special_tokens: bool = True, - truncation: bool = False, - max_length: int | None = None, - ) -> "BatchEncoding": - return self.tokenizer( - text, - text_pair=text_pair, - add_special_tokens=add_special_tokens, - truncation=truncation, - max_length=max_length, - ) - - def get_vocab(self) -> dict[str, int]: - return self.tokenizer.get_vocab() - - def get_added_vocab(self) -> dict[str, int]: - return self._added_vocab.copy() - - def encode( - self, - text: str, - truncation: bool | None = None, - max_length: int | None = None, - add_special_tokens: bool = True, - ) -> list[int]: - return self.tokenizer.encode( - text, - truncation=truncation, - max_length=max_length, - add_special_tokens=add_special_tokens, - ) - - @overload - def convert_tokens_to_ids(self, tokens: str) -> int: ... - - @overload - def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ... - - def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: - return self.tokenizer.convert_tokens_to_ids(tokens) - - def convert_tokens_to_string(self, tokens: list[str]) -> str: - return self.tokenizer.convert_tokens_to_string(tokens) - - def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: - return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens) - - def convert_ids_to_tokens( - self, - ids: list[int], - skip_special_tokens: bool = False, - ) -> list[str]: - return self.tokenizer.convert_ids_to_tokens( - ids, skip_special_tokens=skip_special_tokens - ) +class DeepseekV32Tokenizer(TokenizerLike): + @classmethod + def from_pretrained(cls, *args, **kwargs) -> HfTokenizer: + tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs) + return get_cached_tokenizer(get_deepseek_v32_tokenizer(tokenizer)) diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 617132577..812c262a2 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -10,7 +10,6 @@ import torch import vllm.envs from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( @@ -57,27 +56,6 @@ class XgrammarBackend(StructuredOutputBackend): stop_token_ids=stop_token_ids, add_prefix_space=True, ) - elif isinstance(self.tokenizer, DeepseekV32Tokenizer): - # copy from xgr.TokenizerInfo.from_huggingface() - # because we are using a custom tokenizer wrapper here. - vocab_dict = self.tokenizer.get_vocab() - tokenizer_vocab_size = max(len(vocab_dict), self.tokenizer.max_token_id + 1) - vocab_size = self.vocab_size or tokenizer_vocab_size - # maintain tokenizer's indexing - encoded_vocab = [""] * vocab_size - for token, idx in vocab_dict.items(): - if idx < vocab_size: - encoded_vocab[idx] = token - stop_token_ids = [self.tokenizer.eos_token_id] - backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str() # type: ignore[attr-defined] - metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str) - tokenizer_info = xgr.TokenizerInfo( - encoded_vocab=encoded_vocab, - vocab_type=metadata["vocab_type"], - vocab_size=vocab_size, - stop_token_ids=stop_token_ids, - add_prefix_space=metadata["add_prefix_space"], - ) else: tokenizer_info = xgr.TokenizerInfo.from_huggingface( self.tokenizer,