diff --git a/requirements/common.txt b/requirements/common.txt index 472945d7b..bb1bb2dd9 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -22,7 +22,7 @@ lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.17; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/vllm/envs.py b/vllm/envs.py index a561b52aa..f80bf878f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -106,6 +106,7 @@ if TYPE_CHECKING: VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False + VLLM_XGRAMMAR_CACHE_MB: int = 0 def get_default_cache_root(): @@ -697,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = { # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + + # Control the cache sized used by the xgrammar compiler. The default + # of 512 MB should be enough for roughly 1000 JSON schemas. + # It can be changed with this variable if needed for some reason. + "VLLM_XGRAMMAR_CACHE_MB": + lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), } # end-env-vars-definition diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index b44301f1a..d7e600e9b 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List import torch +import vllm.envs from vllm.logger import init_logger try: @@ -131,8 +132,13 @@ class GrammarCompilerCache: encoded_vocab=config_data.encoded_vocab, metadata=config_data.metadata, ) + cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024 cls._cache[cache_key] = xgr.GrammarCompiler( - tokenizer_info, max_threads=config.max_threads) + tokenizer_info, + max_threads=config.max_threads, + cache_enabled=True, + cache_limit_bytes=cache_size, + ) return cls._cache[cache_key] diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 783a33481..83f2c6436 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import torch +import vllm.envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs @@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend): tokenizer, vocab_size=self.vocab_size, ) - self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + self.compiler = xgr.GrammarCompiler( + tokenizer_info, + max_threads=8, + cache_enabled=True, + cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024, + ) def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: