[V1][Core] Support for Structured Outputs (#12388)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Aaron Pham
2025-03-07 10:19:11 -05:00
committed by GitHub
parent 1e3598edeb
commit 80e9afb5bc
26 changed files with 1528 additions and 715 deletions

View File

@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -61,6 +62,8 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
@@ -69,6 +72,7 @@ class EngineCore:
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
)
# Setup MM Input Mapper.
@@ -131,6 +135,9 @@ class EngineCore:
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
if req.use_structured_output:
# Start grammar compilation asynchronously
self.structured_output_manager.populate_cache(req)
self.scheduler.add_request(req)
@@ -148,11 +155,24 @@ class EngineCore:
if not self.scheduler.has_unfinished_requests():
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
return engine_core_outputs
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: