[Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Signed-off-by: mgoin <michael@neuralmagic.com> Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import time
|
||||
import weakref
|
||||
from functools import partial
|
||||
@@ -507,7 +508,8 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
sampling_params=params,
|
||||
tokenizer=await self.get_tokenizer_async(lora_request),
|
||||
default_guided_backend=self.decoding_config.
|
||||
guided_decoding_backend)
|
||||
guided_decoding_backend,
|
||||
model_config=self.model_config)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
@@ -528,22 +530,30 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
async def build_guided_decoding_logits_processor_async(
|
||||
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
|
||||
default_guided_backend: str) -> SamplingParams:
|
||||
default_guided_backend: str,
|
||||
model_config: ModelConfig) -> SamplingParams:
|
||||
"""Constructs logits processors based on the guided_decoding,
|
||||
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
||||
those fields and adds the constructed logits processors to the
|
||||
logits_processors field. Modifies sampling params in-place and returns
|
||||
the modified sampling params."""
|
||||
if (guided_decoding := sampling_params.guided_decoding) is None:
|
||||
if sampling_params.guided_decoding is None:
|
||||
return sampling_params
|
||||
|
||||
# Defensively copy sampling params since guided decoding logits
|
||||
# processors can have different state for each request
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
guided_decoding = sampling_params.guided_decoding
|
||||
|
||||
logger.debug("Building guided decoding logits processor. "
|
||||
"Params: %s", guided_decoding)
|
||||
|
||||
guided_decoding.backend = guided_decoding.backend or default_guided_backend
|
||||
|
||||
processor = await get_guided_decoding_logits_processor(
|
||||
guided_params=guided_decoding, tokenizer=tokenizer)
|
||||
guided_params=guided_decoding,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config)
|
||||
|
||||
if processor:
|
||||
if sampling_params.logits_processors is None:
|
||||
|
||||
Reference in New Issue
Block a user