[V1][Core] Support for Structured Outputs (#12388)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
@@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):

         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
-            model_config=vllm_config.model_config,
-            cache_config=vllm_config.cache_config,
-            lora_config=vllm_config.lora_config,
+            vllm_config=vllm_config,
             tokenizer=self.tokenizer,
             input_registry=input_registry,
         )
@@ -194,8 +192,8 @@ class AsyncLLM(EngineClient):
         * 3) Adding the Request to the Detokenizer.
         * 4) Adding the Request to the EngineCore (separate process).

-        A separate output_handler loop runs in a background AsyncIO task,
-        pulling outputs from EngineCore and putting them into the
+        A separate output_handler loop runs in a background AsyncIO task,
+        pulling outputs from EngineCore and putting them into the
         per-request AsyncStream.

         The caller of generate() iterates the returned AsyncGenerator,
@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
+from vllm.v1.structured_output import StructuredOutputManager
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -61,6 +62,8 @@ class EngineCore:
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

+        self.structured_output_manager = StructuredOutputManager(vllm_config)
+
         # Setup scheduler.
         self.scheduler = Scheduler(
             scheduler_config=vllm_config.scheduler_config,
@@ -69,6 +72,7 @@ class EngineCore:
             lora_config=vllm_config.lora_config,
             speculative_config=vllm_config.speculative_config,
             log_stats=self.log_stats,
+            structured_output_manager=self.structured_output_manager,
         )

         # Setup MM Input Mapper.
@@ -131,6 +135,9 @@ class EngineCore:
                 request.mm_inputs, request.mm_hashes)

         req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Start grammar compilation asynchronously
+            self.structured_output_manager.populate_cache(req)

         self.scheduler.add_request(req)
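The populate_cache() call above starts grammar compilation off the request-submission path, so add_request() returns immediately even when the grammar still has to be built. Below is a minimal sketch of that compile-and-cache pattern; it is illustrative only — the GrammarCache class and its compile_fn argument are hypothetical and are not the StructuredOutputManager added by this PR.

# Illustrative sketch, not part of this PR: compile grammars in a background
# thread and cache the resulting futures so request submission never blocks.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable


class GrammarCache:

    def __init__(self, compile_fn: Callable[[str], object]):
        self._compile_fn = compile_fn
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures: dict[str, Future] = {}

    def populate_cache(self, key: str) -> None:
        # Kick off compilation asynchronously; the scheduler can later check
        # self._futures[key].done() before scheduling the request.
        if key not in self._futures:
            self._futures[key] = self._executor.submit(self._compile_fn, key)

    def get(self, key: str):
        # Return the compiled grammar, blocking only if it is not ready yet.
        return self._futures[key].result()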
@@ -148,11 +155,24 @@ class EngineCore:
         if not self.scheduler.has_unfinished_requests():
             return EngineCoreOutputs(
-                outputs=[], scheduler_stats=self.scheduler.make_stats())
+                outputs=[],
+                scheduler_stats=self.scheduler.make_stats(),
+            )
+
         scheduler_output = self.scheduler.schedule()
+
+        # This case may occur when the only unfinished requests are
+        # structured output requests where the grammar has not finished
+        # compiling yet, so there's nothing to run.
+        if scheduler_output.total_num_scheduled_tokens == 0:
+            return EngineCoreOutputs(
+                outputs=[],
+                scheduler_stats=self.scheduler.make_stats(),
+            )
+
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
             scheduler_output, output)  # type: ignore

         return engine_core_outputs

     def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
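A hypothetical driver loop makes the zero-token early return above easier to see: the scheduler can report unfinished requests while scheduling nothing, so a caller must treat an empty EngineCoreOutputs as "poll again", not "done". The run_until_done helper below is illustrative only and not part of this PR.

# Illustrative sketch of a loop over EngineCore.step(), not part of this PR.
def run_until_done(engine_core):
    """Yield non-empty outputs until all requests finish."""
    while engine_core.scheduler.has_unfinished_requests():
        outputs = engine_core.step()
        if not outputs.outputs:
            # step() scheduled zero tokens this iteration (e.g. the only
            # unfinished requests are structured-output requests whose
            # grammar is still compiling). Poll again instead of treating
            # the empty result as completion.
            continue
        yield outputs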
@@ -66,9 +66,7 @@ class LLMEngine:
         self.tokenizer.ping()

         # Processor (convert Inputs --> EngineCoreRequests)
-        self.processor = Processor(model_config=vllm_config.model_config,
-                                   cache_config=vllm_config.cache_config,
-                                   lora_config=vllm_config.lora_config,
+        self.processor = Processor(vllm_config=vllm_config,
                                    tokenizer=self.tokenizer,
                                    input_registry=input_registry,
                                    mm_registry=mm_registry)
@@ -4,7 +4,7 @@ import time
 from collections.abc import Mapping
 from typing import Optional, Union

-from vllm.config import CacheConfig, LoRAConfig, ModelConfig
+from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
 from vllm.inputs.parse import is_encoder_decoder_inputs
@@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MMInputCacheClient
+from vllm.v1.structured_output.utils import validate_structured_output_request


 class Processor:

     def __init__(
         self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        lora_config: Optional[LoRAConfig],
+        vllm_config: VllmConfig,
         tokenizer: BaseTokenizerGroup,
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ):

-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
+        self.decoding_config = vllm_config.decoding_config
         self.tokenizer = tokenizer

-        self.generation_config_fields = model_config.try_get_generation_config(
-        )
-        self.input_preprocessor = InputPreprocessor(model_config,
+        self.generation_config_fields = (
+            self.model_config.try_get_generation_config())
+        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,
                                                     mm_registry)
         self.input_processor = input_registry.create_input_processor(
-            model_config)
+            self.model_config)

         # Multi-modal (huggingface) input mapper
-        self.mm_input_cache_client = MMInputCacheClient(model_config)
+        self.mm_input_cache_client = MMInputCacheClient(self.model_config)

         # Multi-modal hasher (for images)
-        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
-            cache_config.enable_prefix_caching
+        self.use_hash = (
+            not self.model_config.disable_mm_preprocessor_cache) or \
+            self.cache_config.enable_prefix_caching

     def _validate_logprobs(
         self,
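With this change every Processor call site hands over a single VllmConfig and the Processor derives the sub-configs itself, including the decoding_config it now needs for structured-output validation. A before/after sketch of a call site (argument values are placeholders):

# Before this change: each sub-config passed explicitly.
processor = Processor(model_config=vllm_config.model_config,
                      cache_config=vllm_config.cache_config,
                      lora_config=vllm_config.lora_config,
                      tokenizer=tokenizer)

# After this change: one VllmConfig; Processor unpacks model_config,
# cache_config, lora_config and decoding_config on its own.
processor = Processor(vllm_config=vllm_config,
                      tokenizer=tokenizer)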
@@ -80,6 +82,8 @@ class Processor:
         self,
         params: SamplingParams,
     ) -> None:
+        self._validate_structured_output(params)
+
         if params.allowed_token_ids is None:
             return
         if not params.allowed_token_ids:
@@ -125,6 +129,21 @@ class Processor:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")

+    def _validate_structured_output(self, params: SamplingParams) -> None:
+        if not params.guided_decoding or not self.decoding_config:
+            return
+        if self.decoding_config.guided_decoding_backend != "xgrammar":
+            raise ValueError(
+                "Only xgrammar structured output is supported in V1.")
+        if (params.guided_decoding.backend
+                and params.guided_decoding.backend != 'xgrammar'):
+            raise ValueError(
+                "Only xgrammar structured output is supported in V1.")
+        if self.vllm_config.speculative_config:
+            raise ValueError("Structured output is not supported with "
+                             "speculative decoding.")
+        validate_structured_output_request(params)
+
     def process_inputs(
         self,
         request_id: str,
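For reference, here is the kind of request this validation path is intended to accept, assuming the existing GuidedDecodingParams API from vllm.sampling_params; the model name and JSON schema below are placeholders and are not introduced by this PR.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Placeholder model; any V1-supported model works here.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="xgrammar")

params = SamplingParams(guided_decoding=GuidedDecodingParams(json={
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}))

# Accepted: the backend is xgrammar and no speculative config is set.
# Requesting backend="outlines" (or enabling speculative decoding) would hit
# the ValueErrors added in _validate_structured_output above.
outputs = llm.generate("Answer as JSON: what is 2 + 2?", params)
print(outputs[0].outputs[0].text)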