Files
vllm/vllm/transformers_utils/tokenizers/mistral.py
Russell Bryant e489ad7a21 [Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that
    tool do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

391 lines
14 KiB
Python
Raw Blame History

# SPDX-License-Identifier: Apache-2.0
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokens
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
# yapf: enable
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)
from vllm.logger import init_logger
from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
# Module-level logger, named after this module; used below to warn when a
# token cannot be mapped back to an id.
logger = init_logger(__name__)
@dataclass
class Encoding:
    """Minimal container for tokenization output.

    ``input_ids`` is either a flat list of token ids (one prompt) or a list
    of such lists (one entry per prompt in a batch).
    """

    # Token ids for a single prompt, or one id list per prompt.
    input_ids: Union[List[int], List[List[int]]]
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize the ``tool_calls`` of every assistant message in place.

    For each message in ``request.messages`` whose ``role`` is
    ``"assistant"``, the ``tool_calls`` entry is drained into a plain list
    and written back onto the message. Assistant messages without a
    ``tool_calls`` entry (or with ``tool_calls`` set to ``None``) end up
    with an empty list. Other messages are left untouched.

    Args:
        request: Object exposing a mutable ``messages`` sequence of
            dict-like messages.

    Returns:
        None. ``request.messages`` is modified in place.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in ChatCompletionAssistantMessageParam
    # model:
    # see:
    # - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    # - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        pending_tool_calls = message.get("tool_calls")
        # `list(...)` drains a one-shot iterator and also accepts values
        # that are already plain sequences; calling `next()` directly on a
        # list (or on an explicit None) would raise TypeError.
        if pending_tool_calls is None:
            validated_tool_calls = []
        else:
            validated_tool_calls = list(pending_tool_calls)

        request.messages[i]["tool_calls"] = validated_tool_calls
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
    """List the files of a model repo snapshot in the local HF Hub cache.

    The cache directory for the repo is derived from
    ``huggingface_hub.constants.HF_HUB_CACHE`` and ``repo_id``. The revision
    is resolved as follows:

    * ``None`` is treated as the ``main`` ref.
    * If ``refs/<revision>`` exists in the repo cache, its content (with
      surrounding whitespace stripped) is used as the snapshot name.
    * Otherwise ``revision`` is used as the snapshot name directly.

    Args:
        repo_id: Hub repository id, e.g. ``"org/name"``.
        revision: Snapshot name or ref name, or ``None`` for ``main``.

    Returns:
        The entry names of ``snapshots/<resolved revision>``, or an empty
        list when the revision cannot be resolved or the snapshot
        directory does not exist.
    """
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]))

    # A ref file maps a ref name to the snapshot it points at. Resolve
    # through it when present; otherwise keep the revision as given.
    ref_name = "main" if revision is None else revision
    ref_file = os.path.join(repo_cache, "refs", ref_name)
    if os.path.isfile(ref_file):
        with open(ref_file, encoding="utf-8") as file:
            # Strip so a trailing newline cannot break the directory lookup.
            revision = file.read().strip()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []
def find_tokenizer_file(files: List[str]):
    """Return the single Mistral tokenizer file name found in ``files``.

    A name qualifies when it is exactly ``tekken.json`` or starts with
    ``tokenizer.model.v``.

    Args:
        files: Candidate file names.

    Returns:
        The one matching file name.

    Raises:
        OSError: If no name matches, or if more than one name matches.
    """
    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
    candidates = list(filter(file_pattern.match, files))

    if not candidates:
        raise OSError(f"Found {len(candidates)} files matching the "
                      f"pattern: {file_pattern}. Make sure that a Mistral "
                      f"tokenizer is present in {files}.")

    if len(candidates) > 1:
        raise OSError(f"Found {len(candidates)} files matching the "
                      f"pattern: {file_pattern}. Make sure only one Mistral "
                      f"tokenizer is present in {files}.")

    (only_match, ) = candidates
    return only_match
class MistralTokenizer:
    """Wrapper around a ``mistral_common`` tokenizer.

    The wrapped tokenizer's inner tokenizer must be either a ``Tekkenizer``
    or a ``SentencePieceTokenizer``; anything else is rejected at
    construction time. The class exposes the properties and methods below
    (``encode``, ``decode``, ``get_vocab``, ``apply_chat_template``, ...)
    on top of it.
    """

    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
        """Wrap ``tokenizer`` and cache its vocabulary.

        Raises:
            TypeError: If the inner tokenizer is neither a ``Tekkenizer``
                nor a ``SentencePieceTokenizer``.
        """
        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer

        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer

        self.is_tekken = isinstance(tokenizer_, Tekkenizer)
        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
        if self.is_tekken:
            # Make sure special tokens will not raise
            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
        elif self.is_spm:
            pass
        else:
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

        self._vocab = tokenizer_.vocab()
        # Convert to a Dict[str, int] to match protocol, but this is a lossy
        # conversion. There may be multiple token ids that decode to the same
        # string due to partial UTF-8 byte sequences being converted to the
        # Unicode replacement character (U+FFFD). Later ids win.
        self._vocab_dict = {
            token: idx
            for idx, token in enumerate(self._vocab)
        }
        self.tokenizer = tokenizer_
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(cls,
                        path_or_repo_id: str,
                        *,
                        revision: Optional[str] = None) -> "MistralTokenizer":
        """Build a tokenizer from a local path or a HF Hub repo id.

        Args:
            path_or_repo_id: One of
                * a path that does not exist locally, which is then
                  treated as a ``"org/name"`` Hub repo id,
                * an existing directory containing exactly one Mistral
                  tokenizer file, or
                * an existing tokenizer file.
            revision: Hub revision; only used for the repo id case.

        Raises:
            AssertionError: If the path does not exist and is not of the
                ``"org/name"`` form, or if it exists but is neither a
                directory nor a regular file.
            OSError: From ``find_tokenizer_file`` when zero or several
                tokenizer files are found.
        """
        path = Path(path_or_repo_id)

        if not path.exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id.")
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision)
        elif path.is_dir():
            tokenizer_file_name = find_tokenizer_file(
                os.listdir(path_or_repo_id))
            tokenizer_file = str(path / tokenizer_file_name)
        else:
            assert path.is_file(), f"Invalid path: {path_or_repo_id}"
            # The path points directly at the tokenizer file. Without this
            # assignment the name would be unbound below.
            tokenizer_file = str(path)

        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                            revision: Optional[str]) -> str:
        """Locate the tokenizer file of a Hub repo and download it.

        The repo's file list is fetched from the Hub. On
        ``ConnectionError`` the local cache listing is used instead, and
        the original error is re-raised if that listing is empty.

        Returns:
            The value returned by ``hf_hub_download`` for the tokenizer
            file.
        """
        try:
            hf_api = HfApi()
            files = hf_api.list_repo_files(repo_id=tokenizer_name,
                                           revision=revision)
        except ConnectionError as exc:
            files = list_local_repo_files(repo_id=tokenizer_name,
                                          revision=revision)

            if len(files) == 0:
                raise exc

        filename = find_tokenizer_file(files)

        tokenizer_file = hf_hub_download(tokenizer_name,
                                         filename=filename,
                                         revision=revision)
        return tokenizer_file

    # the following attributes are set to fit VLLM's design and are used
    # by the guided structured output backends.
    @property
    def all_special_tokens_extended(self) -> List[str]:
        """Special tokens, with ``SpecialTokens`` members replaced by
        their ``.value``."""
        # tekken defines its own extended special tokens list
        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
            special_tokens = self.tokenizer.SPECIAL_TOKENS
        else:
            special_tokens = list(SpecialTokens)
        return [
            s.value if isinstance(s, SpecialTokens) else s
            for s in special_tokens
        ]

    @property
    def all_special_tokens(self) -> List[str]:
        """Alias of ``all_special_tokens_extended``."""
        return self.all_special_tokens_extended

    @property
    def all_special_ids(self) -> List[int]:
        """Position of each special token within ``all_special_tokens``.

        NOTE(review): this reports list positions, presumably relying on
        special tokens occupying the first ids of the vocabulary --
        confirm against the tokenizer implementation.
        """
        # Build the list once; the property rebuilds it on every access.
        special_tokens = self.all_special_tokens
        return [special_tokens.index(t) for t in special_tokens]

    @property
    def bos_token_id(self) -> int:
        """``bos_id`` of the wrapped tokenizer."""
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        """``eos_id`` of the wrapped tokenizer."""
        return self.tokenizer.eos_id

    @property
    def is_fast(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def vocab_size(self) -> int:
        """Number of entries in the cached vocabulary."""
        return len(self._vocab)

    @property
    def max_token_id(self) -> int:
        """``vocab_size - 1`` as computed at construction time."""
        return self._max_token_id

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        prompt: Union[str, List[str], List[int]],
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Tokenize ``prompt`` and wrap the ids in an ``Encoding``.

        Args:
            prompt: A single text, a list of texts, or a list of token ids
                (returned unchanged).
            add_special_tokens: Accepted but not used by this method.
            truncation: Whether to cut each encoded text to ``max_length``.
            max_length: Maximum number of ids kept when truncating.
        """
        input_ids: Union[List[int], List[List[int]]]
        # For List[str], original prompt text
        if is_list_of(prompt, str):
            input_ids_: List[List[int]] = []
            for p in prompt:
                each_input_ids = self.encode_one(p, truncation, max_length)
                input_ids_.append(each_input_ids)
            input_ids = input_ids_
        # For List[int], apply chat template output, already tokens.
        elif is_list_of(prompt, int):
            input_ids = prompt
        # For str, single prompt text
        else:
            input_ids = self.encode_one(prompt, truncation, max_length)
        return Encoding(input_ids=input_ids)

    def get_vocab(self) -> Dict[str, int]:
        """Return the token-string to id mapping built at construction."""
        # NB: the dictionary form of the vocabulary collapses token ids that
        # map to the same string but have different bytes
        return self._vocab_dict

    def get_added_vocab(self) -> Dict[str, int]:
        """Return an empty mapping."""
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode_one(
        self,
        prompt: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Encode one text, optionally keeping only the first
        ``max_length`` ids (``max_length=None`` keeps everything)."""
        # Mistral Tokenizers should not add special tokens
        input_ids = self.encode(prompt)

        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(self, prompt: str) -> List[int]:
        """Encode ``prompt`` with a BOS token and without an EOS token."""
        # `encode` should only be used for prompt completion
        # it should never be used for chat_completion.
        # For chat completion use `apply_chat_template`
        return self.tokenizer.encode(prompt, bos=True, eos=False)

    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[Dict[str, Any]] = None,
                            **kwargs) -> List[int]:
        """Encode a chat conversation into token ids.

        If the last message has role ``"assistant"``, ``prefix=True`` is
        set on it; this mutates the caller's message in place. Extra
        keyword arguments are accepted and ignored.

        Returns:
            The ``tokens`` of the encoded chat completion request.
        """
        last_message = cast(Dict[str, Any], messages[-1])
        if last_message["role"] == "assistant":
            last_message["prefix"] = True

        request = ChatCompletionRequest(messages=messages,
                                        tools=tools)  # type: ignore[type-var]
        encoded = self.mistral.encode_chat_completion(request)

        return encoded.tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join ``tokens`` (as produced by ``convert_ids_to_tokens``) into
        a single string."""
        if self.is_tekken:
            # NOTE(review): `is` only matches the enum member itself, not
            # an equal plain string -- confirm this is intended.
            tokens = [
                t for t in tokens
                if (t is SpecialTokens.tool_calls
                    or t not in self.tokenizer._all_special_tokens)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                shift = self.tokenizer.num_special_tokens

                def _token_to_id(t: str):
                    t_bytes = t.encode("utf-8") \
                        if not isinstance(t, bytes) else t
                    try:
                        return shift + \
                            self.tokenizer._tekken_token2id_nospecial[t_bytes]
                    except KeyError:
                        logger.warning(
                            "Failed to convert token %s to id,"
                            " replacing with <unk>", t_bytes)
                        return self.tokenizer.unk_id

                ids = [_token_to_id(t) for t in tokens]
                decoded = self.tokenizer.decode(ids)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            special_tokens = {SpecialTokens.tool_calls}
            regular_tokens: List[str] = []
            decoded_list = []

            for token in tokens:
                if token in special_tokens:
                    # Flush the pending regular tokens before emitting the
                    # special token verbatim.
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(regular_tokens))
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens))  # type: ignore

            decoded = ''.join(decoded_list)

        return decoded

    # WARN: Outlines logits processors can overwrite this method.
    # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
    # for more.
    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode one id or a list of ids into text.

        Raises:
            AssertionError: If ``skip_special_tokens`` is false.
        """
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids)

    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        """Map token ids to token pieces.

        For tekken tokenizers the ids are filtered first (see below), and
        byte pieces are returned instead of strings when any piece
        contains the Unicode replacement character.

        Raises:
            AssertionError: If ``skip_special_tokens`` is false.
        """
        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call
            # NOTE(review): `>` also drops the id equal to
            # num_special_tokens; confirm whether `>=` was intended.
            ids = [
                i for i in ids if i > self.tokenizer.num_special_tokens or i ==
                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
            ]

        tokens = [self.tokenizer.id_to_piece(id) for id in ids]

        if any("\ufffd" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if the underlying tokenizer is sentencepiece, we just keep the
            # replacement character (U+FFFD)
            tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

        return tokens