[V1][Core] Support for Structured Outputs (#12388)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Aaron Pham
2025-03-07 10:19:11 -05:00
committed by GitHub
parent 1e3598edeb
commit 80e9afb5bc
26 changed files with 1528 additions and 715 deletions

View File

@@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
@@ -194,8 +192,8 @@ class AsyncLLM(EngineClient):
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,

View File

@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -61,6 +62,8 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
@@ -69,6 +72,7 @@ class EngineCore:
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
)
# Setup MM Input Mapper.
@@ -131,6 +135,9 @@ class EngineCore:
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
if req.use_structured_output:
# Start grammar compilation asynchronously
self.structured_output_manager.populate_cache(req)
self.scheduler.add_request(req)
@@ -148,11 +155,24 @@ class EngineCore:
if not self.scheduler.has_unfinished_requests():
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
return engine_core_outputs
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:

View File

@@ -66,9 +66,7 @@ class LLMEngine:
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)

View File

@@ -4,7 +4,7 @@ import time
from collections.abc import Mapping
from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
@@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.structured_output.utils import validate_structured_output_request
class Processor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.tokenizer = tokenizer
self.generation_config_fields = model_config.try_get_generation_config(
)
self.input_preprocessor = InputPreprocessor(model_config,
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_processor = input_registry.create_input_processor(
model_config)
self.model_config)
# Multi-modal (huggingface) input mapper
self.mm_input_cache_client = MMInputCacheClient(model_config)
self.mm_input_cache_client = MMInputCacheClient(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
self.cache_config.enable_prefix_caching
def _validate_logprobs(
self,
@@ -80,6 +82,8 @@ class Processor:
self,
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
@@ -125,6 +129,21 @@ class Processor:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if (params.guided_decoding.backend
and params.guided_decoding.backend != 'xgrammar'):
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with "
"speculative decoding.")
validate_structured_output_request(params)
def process_inputs(
self,
request_id: str,