# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import xgrammar as xgr

    from vllm.v1.request import Request
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


class StructuredOutputManager:
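    """Manages structured output (grammar-constrained decoding) requests.

    Grammars are compiled with xgrammar on a small thread pool so that
    compilation overlaps with other engine work, and a preallocated token
    bitmask is filled for each batch and sent to the GPU workers. xgrammar
    and the tokenizer are initialized lazily, the first time a structured
    output request is seen (see _delayed_init).
    """
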
    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.init_complete = False

    def _delayed_init(self):
        """Initialization delayed until we know it is needed."""
        tokenizer_group = init_tokenizer_from_configs(
            model_config=self.vllm_config.model_config,
            scheduler_config=self.vllm_config.scheduler_config,
            parallel_config=self.vllm_config.parallel_config,
            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()

        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.vocab_size = len(tokenizer.get_vocab())
        if isinstance(tokenizer, MistralTokenizer):
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
                encoded_vocab = [
                    token for token, _ in sorted(
                        tokenizer.get_vocab().items(),
                        key=lambda x: x[1],
                    )
                ]
                stop_token_ids = None
                if hasattr(
                        tokenizer,
                        "eos_token_id",
                ) and tokenizer.eos_token_id is not None:
                    stop_token_ids = [tokenizer.eos_token_id]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e
            tokenizer_info = xgr.TokenizerInfo(
                encoded_vocab=encoded_vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
                vocab_type=xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        else:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                vocab_size=self.vocab_size,
            )
        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O bound.
        # We also know we would never dominate CPU usage with just grammar
        # compilation, so we set it to half the number of CPUs.
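        # For example, a 16-CPU machine gets (16 + 1) // 2 = 8 worker threads.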
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self._grammar_bitmask = xgr.allocate_token_bitmask(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.vocab_size,
        )

        self.init_complete = True

    def grammar_init(self, request: Request) -> None:
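        """Kick off grammar compilation for a structured output request.

        Compilation is submitted to the thread pool and the resulting Future
        is stored on the request's structured_output_request.
        """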
        if request.structured_output_request is None:
            return

        # The first time this is called, we need to finish initialization
        # of xgrammar. We defer it to avoid the import of xgrammar and
        # initialization cost if it is not going to be used.
        if not self.init_complete:
            self._delayed_init()

        grammar: Future[Grammar] = self.executor.submit(
            self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(self, request: Request) -> Grammar:
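        """Compile the request's grammar on the executor thread and wrap the
        compiled context in a Grammar."""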
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        if request_type == StructuredOutputOptions.JSON:
            # TODO -- allow any_whitespace to be configurable
            # pending merge of https://github.com/vllm-project/vllm/pull/12744
            ctx = self.compiler.compile_json_schema(grammar_spec,
                                                    any_whitespace=False)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_builtin_json_grammar()
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")

        return Grammar(
            matcher=xgr.GrammarMatcher(ctx),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        batch_len: int,
    ) -> Optional[npt.NDArray[np.int32]]:
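        """Build the token bitmask for the structured output requests in
        this batch.

        Returns None if no request in the batch uses structured output;
        otherwise the preallocated bitmask, truncated to batch_len rows and
        converted to a numpy array for cheap serialization to the GPU
        workers.
        """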
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        # Fill the bitmask using the index of each request equal to its
        # position in the batch. Resize the bitmask down to the size of
        # the batch.
        bitmask_tensor = self._grammar_bitmask
        for req_id, batch_index in structured_output_request_ids.items():
            request = requests[req_id].structured_output_request
            assert request is not None and request.grammar is not None
            if not request.grammar.matcher.is_terminated():
                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
        if batch_len < self._grammar_bitmask.shape[0]:
            bitmask_tensor = self._grammar_bitmask[:batch_len]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
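        # Illustrative sketch (not part of this module): a GPU worker that
        # receives the array returned below (call it np_bitmask) could apply
        # it to the logits of the structured output rows with xgrammar's
        # bitmask kernel, roughly:
        #
        #   import torch
        #   import xgrammar as xgr
        #
        #   bitmask = torch.from_numpy(np_bitmask).to(logits.device)
        #   xgr.apply_token_bitmask_inplace(logits, bitmask)
        #
        # Tokens the grammar disallows get their logits masked out, so
        # sampling can only choose tokens the grammar currently accepts.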
        return bitmask_tensor.numpy()