[Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Aaron Pham
2024-12-03 02:17:00 -05:00
committed by GitHub
parent 3257d449fa
commit 9323a3153b
11 changed files with 385 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import copy
import time
import weakref
from functools import partial
@@ -507,7 +508,8 @@ class _AsyncLLMEngine(LLMEngine):
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
guided_decoding_backend,
model_config=self.model_config)
self._add_processed_request(
request_id=request_id,
@@ -528,22 +530,30 @@ class _AsyncLLMEngine(LLMEngine):
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
default_guided_backend: str,
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. When guided decoding is requested, works on a
shallow copy of the sampling params and returns that copy; otherwise
returns the given sampling params unchanged."""
if (guided_decoding := sampling_params.guided_decoding) is None:
if sampling_params.guided_decoding is None:
return sampling_params
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=model_config)
if processor:
if sampling_params.logits_processors is None: