[Bugfix]: Fix TokenizerLike interface (#30009)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Authored by Rohan Potdar on 2025-12-05 22:56:40 -06:00; committed by GitHub.
parent e858bc4d14
commit 40a046cd82
8 changed files with 78 additions and 52 deletions

View File

@@ -32,7 +32,6 @@ from typing import Any, cast
 import numpy as np
 from PIL import Image
-from transformers import PreTrainedTokenizerBase
 from typing_extensions import deprecated

 from vllm.lora.request import LoRARequest
@@ -189,7 +188,7 @@ class BenchmarkDataset(ABC):
     @abstractmethod
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         request_id_prefix: str = "",
         no_oversample: bool = False,
@@ -201,7 +200,7 @@ class BenchmarkDataset(ABC):
         for generating a list of SampleRequest objects.

         Args:
-            tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
+            tokenizer (TokenizerLike): The tokenizer to be used
                 for processing the dataset's text.
             num_requests (int): The number of sample requests to generate.
             request_id_prefix (str): The prefix of request_id.
@@ -380,7 +379,7 @@ def process_video(video: Any) -> Mapping[str, Any]:
 def gen_prompt_decode_to_target_len(
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: TokenizerLike,
     token_sequence: list[int],
     target_token_len: int,
     max_retry: int = 10,
@@ -468,7 +467,7 @@ class RandomDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         request_id_prefix: str = "",
         no_oversample: bool = False,
@@ -580,7 +579,7 @@ class RandomDataset(BenchmarkDataset):
         range_ratio: float,
         input_len: int,
         output_len: int,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Get the sampling parameters for the dataset.
@@ -626,7 +625,7 @@ class RandomDataset(BenchmarkDataset):
     def generate_token_sequence(
         self,
         *,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         prefix_token_ids: list[int],
         prefix_len: int,
         vocab_size: int,
@@ -686,7 +685,7 @@ class RandomDatasetForReranking(RandomDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         request_id_prefix: str = "",
         range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
@@ -716,7 +715,11 @@ class RandomDatasetForReranking(RandomDataset):
         doc_lens, _, doc_offsets = self.get_sampling_params(
             num_requests, range_ratio, doc_len_param, 0, tokenizer
         )
         vocab_size = tokenizer.vocab_size
+        prohibited_tokens = tokenizer.all_special_ids
+        all_tokens = np.arange(vocab_size)
+        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))

         query_prompt, query_input_len, token_mismatch_total = (
             self.generate_token_sequence(
@@ -727,6 +730,7 @@ class RandomDatasetForReranking(RandomDataset):
                 input_len=query_len,
                 offset=int(query_offsets[0]),
                 index=0,
+                allowed_tokens=allowed_tokens,
             )
         )
@@ -740,6 +744,7 @@ class RandomDatasetForReranking(RandomDataset):
                 input_len=int(doc_lens[i]),
                 offset=int(doc_offsets[i]),
                 index=i + 1,
+                allowed_tokens=allowed_tokens,
             )
             token_mismatch_total += token_mismatch
             requests.append((prompt, total_input_len))
@@ -1077,7 +1082,7 @@ class RandomMultiModalDataset(RandomDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         request_id_prefix: str = "",
         no_oversample: bool = False,
@@ -1231,7 +1236,7 @@ class ShareGPTDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         lora_path: str | None = None,
         max_loras: int | None = None,
@@ -1633,7 +1638,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
     )


-def get_samples(args, tokenizer) -> list[SampleRequest]:
+def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
     if not hasattr(args, "request_id_prefix"):
         args.request_id_prefix = ""
@@ -1971,7 +1976,7 @@ class CustomDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         lora_path: str | None = None,
         max_loras: int | None = None,
@@ -2101,7 +2106,7 @@ class SonnetDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer,
+        tokenizer: TokenizerLike,
         num_requests: int,
         prefix_len: int = DEFAULT_PREFIX_LEN,
         input_len: int = DEFAULT_INPUT_LEN,
@@ -2202,7 +2207,7 @@ class BurstGPTDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         max_loras: int | None = None,
         lora_path: str | None = None,
@@ -2287,7 +2292,7 @@ class ConversationDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
@@ -2347,7 +2352,7 @@ class MultiModalConversationDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
@@ -2416,7 +2421,7 @@ class VisionArenaDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
@@ -2470,7 +2475,7 @@ class MMVUDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
@@ -2531,7 +2536,7 @@ class InstructCoderDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
@@ -2595,7 +2600,7 @@ class MTBenchDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
@@ -2661,7 +2666,7 @@ class BlazeditDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         skip_chat_template: bool = False,
@@ -2742,7 +2747,7 @@ class AIMODataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         request_id_prefix: str = "",
@@ -2852,7 +2857,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         request_id_prefix: str = "",
         no_oversample: bool = False,
@@ -2924,7 +2929,7 @@ class ASRDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         request_id_prefix: str = "",
@@ -3002,7 +3007,7 @@ class MLPerfDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         request_id_prefix: str = "",
@@ -3081,7 +3086,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         prefix_len: int = DEFAULT_PREFIX_LEN,
         suffix_len: int = DEFAULT_SUFFIX_LEN,
@@ -3167,7 +3172,7 @@ class MMStarDataset(HuggingFaceDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: TokenizerLike,
         num_requests: int,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
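
Note on the RandomDatasetForReranking hunks above: random draws are now restricted to non-special token ids, presumably because TokenizerLike implementations (e.g. mistral_common's) count special ids inside vocab_size and those should not appear in synthetic prompts. A minimal sketch of the same filtering, with `tok` standing in for any TokenizerLike object (helper name hypothetical):

    import numpy as np

    def allowed_token_ids(tok) -> np.ndarray:
        # drop BOS/EOS/PAD and other special ids so randomly sampled
        # prompt tokens are always ordinary vocabulary entries
        all_tokens = np.arange(tok.vocab_size)
        return np.array(sorted(set(all_tokens.tolist()) - set(tok.all_special_ids)))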

View File

@@ -36,7 +36,6 @@ from typing import Any, Literal
 import aiohttp
 import numpy as np
 from tqdm.asyncio import tqdm
-from transformers import PreTrainedTokenizerBase

 from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samples
 from vllm.benchmarks.lib.endpoint_request_func import (
@@ -47,7 +46,7 @@ from vllm.benchmarks.lib.endpoint_request_func import (
 )
 from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
 from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
-from vllm.tokenizers import get_tokenizer
+from vllm.tokenizers import TokenizerLike, get_tokenizer
 from vllm.utils.gc_utils import freeze_gc_heap
 from vllm.utils.network_utils import join_host_port
@@ -286,7 +285,7 @@ def calculate_metrics(
     input_requests: list[SampleRequest],
     outputs: list[RequestFuncOutput],
     dur_s: float,
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: TokenizerLike,
     selected_percentiles: list[float],
     goodput_config_dict: dict[str, float],
 ) -> tuple[BenchmarkMetrics, list[int]]:
@@ -489,7 +488,7 @@ async def benchmark(
     base_url: str,
     model_id: str,
     model_name: str,
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: TokenizerLike,
     input_requests: list[SampleRequest],
     logprobs: int | None,
     request_rate: float,
@@ -1032,6 +1031,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
         type=str,
         help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
+    parser.add_argument(
+        "--tokenizer-mode",
+        type=str,
+        default="auto",
+        help="""Tokenizer mode:\n
+        - "auto" will use the tokenizer from `mistral_common` for Mistral models
+          if available, otherwise it will use the "hf" tokenizer.\n
+        - "hf" will use the fast tokenizer if available.\n
+        - "slow" will always use the slow tokenizer.\n
+        - "mistral" will always use the tokenizer from `mistral_common`.\n
+        - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
+        - Other custom values can be supported via plugins.""",
+    )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
         "--logprobs",
@@ -1228,18 +1240,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="Common prefix length shared by all prompts (used by random dataset)",
     )
-    parser.add_argument(
-        "--tokenizer-mode",
-        type=str,
-        default="auto",
-        choices=["auto", "slow", "mistral", "custom"],
-        help='The tokenizer mode.\n\n* "auto" will use the '
-        'fast tokenizer if available.\n* "slow" will '
-        "always use the slow tokenizer. \n* "
-        '"mistral" will always use the `mistral_common` tokenizer. \n*'
-        '"custom" will use --tokenizer to select the preregistered tokenizer.',
-    )
     parser.add_argument(
         "--served-model-name",
         type=str,
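
The --tokenizer-mode flag is not new to this file; the two hunks above move it next to --tokenizer, drop the hard-coded choices list so plugin modes are accepted, and align its help text with the ModelConfig docstring (see the config hunk below). A hedged usage sketch of the resolution path the benchmark follows; the model name is a hypothetical example, the call shape mirrors the one in the throughput diff below:

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        "mistralai/Mistral-7B-Instruct-v0.3",
        tokenizer_mode="auto",  # "auto" now prefers mistral_common for Mistral models
        trust_remote_code=False,
    )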

View File

@@ -14,7 +14,7 @@ from typing import Any
 import torch
 import uvloop
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase

 from vllm.benchmarks.datasets import (
     AIMODataset,
@@ -35,6 +35,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
+from vllm.tokenizers import TokenizerLike, get_tokenizer
 from vllm.utils.async_utils import merge_async_iterators
@@ -246,12 +247,15 @@ async def run_vllm_async(
 def run_hf(
     requests: list[SampleRequest],
     model: str,
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: TokenizerLike,
     n: int,
     max_batch_size: int,
     trust_remote_code: bool,
     disable_detokenize: bool = False,
 ) -> float:
+    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
+        "the hf backend only supports HF tokenizers"
+    )
     llm = AutoModelForCausalLM.from_pretrained(
         model, dtype=torch.float16, trust_remote_code=trust_remote_code
     )
@@ -692,15 +696,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
 def main(args: argparse.Namespace):
+    if args.tokenizer is None:
+        args.tokenizer = args.model
     validate_args(args)
     if args.seed is None:
         args.seed = 0
     random.seed(args.seed)

     # Sample the requests.
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer, trust_remote_code=args.trust_remote_code
-    )
+    if (
+        args.backend == "hf" or args.backend == "mii"
+    ) and args.tokenizer_mode == "auto":
+        # mistral_common tokenizer is only supported on vllm and vllm-chat backends;
+        # for hf and mii backends, we use hf tokenizer
+        args.tokenizer_mode = "hf"
+    tokenizer = get_tokenizer(
+        args.tokenizer,
+        tokenizer_mode=args.tokenizer_mode,
+        trust_remote_code=args.trust_remote_code,
+    )
     requests = get_requests(args, tokenizer)
     is_multi_modal = any(request.multi_modal_data is not None for request in requests)
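
The new assert in run_hf is a runtime narrowing: TokenizerLike is a structural protocol, but the HF generation path calls PreTrainedTokenizerBase-only APIs, so a mistral_common or deepseek_v32 tokenizer should fail fast here rather than crash deeper in. A minimal sketch of the pattern (helper name hypothetical):

    from transformers import PreTrainedTokenizerBase

    def require_hf_tokenizer(tokenizer: object) -> PreTrainedTokenizerBase:
        # narrow the structural type to the concrete HF base class before
        # using HF-specific methods; mirrors the assert added in run_hf
        assert isinstance(tokenizer, PreTrainedTokenizerBase), (
            "the hf backend only supports HF tokenizers"
        )
        return tokenizer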

View File

@@ -136,7 +136,8 @@ class ModelConfig:
     name or path will be used."""
     tokenizer_mode: TokenizerMode | str = "auto"
     """Tokenizer mode:\n
-    - "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
+    - "auto" will use the tokenizer from `mistral_common` for Mistral models
+      if available, otherwise it will use the "hf" tokenizer.\n
     - "hf" will use the fast tokenizer if available.\n
     - "slow" will always use the slow tokenizer.\n
     - "mistral" will always use the tokenizer from `mistral_common`.\n

View File

@@ -54,6 +54,9 @@ class DeepseekV32Tokenizer(HfTokenizer):
         prompt_str = encode_messages(messages, **encode_config)  # type: ignore
         return prompt_str

+    def num_special_tokens_to_add(self) -> int:
+        return len(self.encode(""))
+
     @property
     def all_special_tokens(self) -> list[str]:
         return self.tokenizer.all_special_tokens
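
The len(self.encode("")) trick works because encoding an empty string leaves only the tokens the tokenizer inserts on its own (BOS/CLS and friends). An illustration with a stock HF tokenizer, assuming the model files are cached or downloadable:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    print(tok.encode(""))       # [101, 102], the [CLS] and [SEP] ids
    print(len(tok.encode("")))  # 2 special tokens added per sequence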

View File

@@ -309,6 +309,9 @@ class MistralTokenizer(TokenizerLike):
             for i in all_special_ids
         ]

+    def num_special_tokens_to_add(self) -> int:
+        return len(self.encode(""))
+
     # the following attributes are set to fit vLLM's design and are used
     # by the structured output backends.
     @property
@@ -421,6 +424,7 @@ class MistralTokenizer(TokenizerLike):
     ) -> list[int]:
         add_generation_prompt = kwargs.pop("add_generation_prompt", False)
         continue_final_message = kwargs.get("continue_final_message", False)
+        tokenize = kwargs.get("tokenize", True)
         padding = kwargs.get("padding", False)
         truncation = kwargs.get("truncation", False)
         max_length = kwargs.get("max_length")
@@ -433,7 +437,7 @@ class MistralTokenizer(TokenizerLike):
             conversation=messages,
             tools=tools,
             continue_final_message=continue_final_message,
-            tokenize=True,
+            tokenize=tokenize,
             padding=padding,
             truncation=truncation,
             max_length=max_length,
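
Forwarding `tokenize` means callers of MistralTokenizer.apply_chat_template are no longer forced into token ids. A hedged sketch; the message shape follows the OpenAI-style dicts vLLM passes, and the exact call signature is per the class above:

    # `tokenizer` is a MistralTokenizer instance
    messages = [{"role": "user", "content": "Hello!"}]
    token_ids = tokenizer.apply_chat_template(messages)  # tokenize defaults to True
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False)  # now honored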

View File

@@ -22,6 +22,9 @@ class TokenizerLike(Protocol):
     ) -> "TokenizerLike":
         raise NotImplementedError

+    def num_special_tokens_to_add(self) -> int:
+        raise NotImplementedError
+
     @property
     def all_special_tokens(self) -> list[str]:
         raise NotImplementedError
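
With num_special_tokens_to_add on the protocol, every tokenizer implementation (including plugins) must now provide it. A minimal structural sketch; the class name is hypothetical and only the new member is shown:

    class MyPluginTokenizer:
        """Hypothetical plugin; structurally satisfies TokenizerLike's new method."""

        def encode(self, text: str) -> list[int]:
            raise NotImplementedError  # real implementations tokenize here

        def num_special_tokens_to_add(self) -> int:
            # same strategy the Mistral and DeepSeek wrappers use above:
            # count what encoding the empty string produces
            return len(self.encode(""))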

View File

@@ -183,7 +183,7 @@ def get_tokenizer(
         "`tokenizer_mode='custom'` when initializing vLLM.",
         tokenizer_args,
         str(tokenizer_kwargs),
-        tokenizer_mode,
+        tokenizer_name,
     )
     tokenizer_mode = str(tokenizer_name)
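
The final hunk appears to fix a logging slip: the deprecation warning's lazy %-style placeholders were fed tokenizer_mode where the registered tokenizer name belonged, just before tokenizer_mode is overwritten with str(tokenizer_name). A reduced sketch of the failure mode (message text and variable values hypothetical):

    import logging

    logger = logging.getLogger(__name__)
    tokenizer_name, tokenizer_mode = "my_custom_tok", "custom"

    logger.warning("pass tokenizer=%r with tokenizer_mode='custom'", tokenizer_mode)  # bug: renders 'custom'
    logger.warning("pass tokenizer=%r with tokenizer_mode='custom'", tokenizer_name)  # fix: renders the name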