[V1][Core] Support MistralTokenizer for Structured Output (#14625)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2025-03-11 22:40:09 -04:00
committed by GitHub
parent 80e78d02ac
commit 77a318bd01
2 changed files with 102 additions and 26 deletions

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
@@ -40,8 +41,40 @@ class StructuredOutputManager:
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer, vocab_size=self.vocab_size)
if isinstance(tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
# The default max_workers if not specified is the number of CPUs * 5,
@@ -51,7 +84,9 @@ class StructuredOutputManager:
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size,
)
self.init_complete = True