# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Grok-2 .tok.json format."""

import functools
import json
from collections.abc import Collection, Set
from pathlib import Path
from typing import Any, Literal, overload

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from transformers import BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger

from .protocol import TokenizerLike

logger = init_logger(__name__)

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS}
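# Fallback Jinja chat template, used when tokenizer_config.json provides none:
# it renders System/Human/Assistant turns, each terminated by <|separator|>.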
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ 'Human: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
    "{% elif message['role'] == 'system' %}"
    "{{ 'System: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ 'Assistant: ' + message['content'] + '<|separator|>\\n\\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ 'Assistant:' }}"
    "{% endif %}"
)

# Default + separate each single digit.
PAT_STR_B = (
    r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
    r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)

def _maybe_load_tokenizer_config(
    model_path: Path,
    *,
    repo_id: str | None,
    revision: str | None,
    download_dir: str | None,
) -> dict[str, Any]:
    config_path = model_path / "tokenizer_config.json"
    if config_path.is_file():
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)

    if repo_id is None:
        return {}

    try:
        config_file = hf_hub_download(
            repo_id=repo_id,
            filename="tokenizer_config.json",
            revision=revision,
            cache_dir=download_dir,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError):
        # If the repo, revision, or file does not exist, fall back silently.
        return {}
    except HfHubHTTPError as exc:
        logger.warning(
            "Failed to download tokenizer_config.json from %s. "
            "This may be due to a network or authentication issue. "
            "The default chat template will be used. Error: %s",
            repo_id,
            exc,
        )
        return {}

    try:
        with Path(config_file).open("r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as exc:
        logger.warning(
            "Failed to parse tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}
    except OSError as exc:
        logger.warning(
            "Failed to open tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}

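# A Grok-2 ``tokenizer.tok.json`` file (as read below) is expected to contain
# ``regular_tokens`` and ``special_tokens`` entries of the form
# ``{"bytes": [...], "token": <rank>}``, a ``word_split`` version tag, and
# optionally ``pat_str``, ``vocab_size``, and ``default_allowed_special``.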
def _load_tiktoken_encoding(
    vocab_file: Path,
) -> tuple[Any, dict[str, int]]:
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError("Grok-2 tokenizer requires the `tiktoken` package.") from exc

    with vocab_file.open("rb") as f:
        xtok_dict = json.load(f)

    mergeable_ranks = {
        bytes(item["bytes"]): item["token"]
        for item in xtok_dict.get("regular_tokens", [])
    }
    special_tokens = {
        bytes(item["bytes"]).decode("utf-8", errors="replace"): item["token"]
        for item in xtok_dict.get("special_tokens", [])
    }

    if xtok_dict.get("word_split") == "V1":
        pat_str = PAT_STR_B
    else:
        raise ValueError(f"Unknown word_split: {xtok_dict.get('word_split')!r}")

    pat_str = xtok_dict.get("pat_str", pat_str)

    kwargs = {
        "name": str(vocab_file),
        "pat_str": pat_str,
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }

    if "vocab_size" in xtok_dict:
        kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

    tokenizer = tiktoken.Encoding(**kwargs)

    default_allowed_special: set[str] | None = None
    if "default_allowed_special" in xtok_dict:
        default_allowed_special = {
            bytes(bytes_list).decode("utf-8", errors="replace")
            for bytes_list in xtok_dict["default_allowed_special"]
        }

    tokenizer._default_allowed_special = default_allowed_special or set()
    tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

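    # Patch ``Encoding.encode`` on this instance so that the tokenizer's
    # default special tokens are always allowed and unlisted special strings
    # never raise (they are encoded as ordinary text instead).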
    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | Set[str] = set(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        del disallowed_special
        if isinstance(allowed_special, set):
            allowed_special |= self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=(),
        )

    tokenizer.encode = functools.partial(encode_patched, tokenizer)
    tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
    tokenizer._default_allowed_special |= set(
        CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
    )

    return tokenizer, special_tokens

class Grok2Tokenizer(TokenizerLike):
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "Grok2Tokenizer":
        if args:
            logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
            model_path = path.parent
            repo_id = None
        elif path.is_dir():
            vocab_file = path / "tokenizer.tok.json"
            model_path = path
            repo_id = None
        else:
            vocab_file = Path(
                hf_hub_download(
                    repo_id=str(path_or_repo_id),
                    filename="tokenizer.tok.json",
                    revision=revision,
                    cache_dir=download_dir,
                )
            )
            model_path = vocab_file.parent
            repo_id = str(path_or_repo_id)

        if not vocab_file.is_file():
            raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.")

        config = _maybe_load_tokenizer_config(
            model_path,
            repo_id=repo_id,
            revision=revision,
            download_dir=download_dir,
        )

        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
            chat_template=config.get("chat_template"),
            init_kwargs=config,
        )

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
        chat_template: str | None,
        init_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self.init_kwargs = init_kwargs or {}
        self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file)

        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token, token_id in self._tokenizer._mergeable_ranks.items():
            token_str = token.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        for token, token_id in self._special_tokens.items():
            self._token_to_id[token] = token_id
            self._id_to_token[token_id] = token

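        # No dedicated BOS token is looked up here; fall back through the
        # separator, then pad, then eos, and finally token id 0.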
        bos_token_id = self._special_tokens.get(SEP)
        if bos_token_id is None:
            bos_token_id = self._special_tokens.get(PAD)
        if bos_token_id is None:
            bos_token_id = self._special_tokens.get(EOS)
        if bos_token_id is None:
            bos_token_id = 0
        self._bos_token_id = bos_token_id

        self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id)
        self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
        self._unk_token_id = self._pad_token_id

        self._max_chars_per_token = max(len(tok) for tok in self._token_to_id)

    def num_special_tokens_to_add(self) -> int:
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        return list(self._special_tokens.keys())

    @property
    def all_special_ids(self) -> list[int]:
        return list(self._special_tokens.values())

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        return self._tokenizer.n_vocab - 1

    @property
    def max_chars_per_token(self) -> int:
        return self._max_chars_per_token

    @property
    def truncation_side(self) -> str:
        return self._truncation_side

    def get_vocab(self) -> dict[str, int]:
        return dict(self._token_to_id)

    def get_added_vocab(self) -> dict[str, int]:
        return dict(self._special_tokens)

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        if max_length is None or len(tokens) <= max_length:
            return tokens
        if self.truncation_side == "left":
            return tokens[-max_length:]
        return tokens[:max_length]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        del add_special_tokens
        tokens = self._tokenizer.encode(text)
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            ids = [
                token_id
                for token_id in ids
                if token_id not in self._special_tokens.values()
            ]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        tokens = []
        for token_id in ids:
            if skip_special_tokens and token_id in self._special_tokens.values():
                continue
            tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
        return tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> BatchEncoding:
        if text_pair is not None:
            raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.")

        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
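            # No padding is applied, so the attention mask is all ones.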
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )

        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        del tools
        return chat_template or self._chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )
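        # Always ask the HF helper for the rendered prompt string; if token
        # ids are requested, they are produced with this tokenizer below.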
kwargs["return_dict"] = False
|
|
prompt = hf_chat_utils.apply_chat_template(
|
|
conversation=messages,
|
|
chat_template=template,
|
|
tools=tools,
|
|
**kwargs,
|
|
)
|
|
if tokenize:
|
|
return self.encode(prompt, add_special_tokens=False)
|
|
return prompt
|
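
# Minimal usage sketch (the path below is hypothetical; any local directory or
# HF repo containing ``tokenizer.tok.json`` should work):
#
#     tokenizer = Grok2Tokenizer.from_pretrained("/path/to/grok-2")
#     ids = tokenizer.encode("Hello, Grok!")
#     text = tokenizer.decode(ids)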