[V1] Structured Outputs + Thinking compatibility (#16577)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Authored by Aaron Pham on 2025-05-14 18:45:24 -04:00; committed by GitHub.
parent d93c976a0d · commit 2fc9075b82
10 changed files with 233 additions and 75 deletions


@@ -7,16 +7,23 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch
from vllm.reasoning import ReasoningParser
from vllm.v1.request import Request
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
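Aside: the TYPE_CHECKING/LazyLoader split above keeps torch off the import path until it is actually used. Below is a minimal sketch of that lazy-import pattern, assuming vllm.utils.LazyLoader behaves like the common module-level lazy loader; the real implementation may differ in detail, and _LazyModule here is a hypothetical stand-in.

import importlib
import types


class _LazyModule(types.ModuleType):
    """Illustrative stand-in for vllm.utils.LazyLoader: defers the real
    import until the first attribute access."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self.__name__)
        # Swap the placeholder for the real module so later lookups are direct.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)


# Mirrors the diff: torch is only imported when an attribute is first touched.
torch = _LazyModule("torch", globals(), "torch")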
@@ -26,9 +33,11 @@ class StructuredOutputManager:
def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.reasoner: Optional[ReasoningParser] = None
self.vllm_config = vllm_config
self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)
# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
@@ -36,24 +45,43 @@ class StructuredOutputManager:
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None)
reasoning_backend = vllm_config.decoding_config.reasoning_backend
if reasoning_backend:
reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None:
return
if TYPE_CHECKING:
assert request.sampling_params.guided_decoding is not None
# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
backend = request.sampling_params.guided_decoding.backend
vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar":
from vllm.v1.structured_output.backend_xgrammar import (
XgrammarBackend)
-self.backend = XgrammarBackend(self.vllm_config)
+self.backend = XgrammarBackend(
+    self.vllm_config,
+    tokenizer=self.tokenizer,
+    vocab_size=vocab_size,
+)
elif backend == "guidance":
-self.backend = GuidanceBackend(self.vllm_config)
+self.backend = GuidanceBackend(
+    self.vllm_config,
+    tokenizer=self.tokenizer,
+    vocab_size=vocab_size,
+)
else:
raise ValueError(
f"Unsupported structured output backend: {backend}")
@@ -87,14 +115,14 @@ class StructuredOutputManager:
if not structured_output_request_ids:
return None
+max_num_spec_tokens = 0
+if self.vllm_config.speculative_config is not None:
+    max_num_spec_tokens = \
+        self.vllm_config.speculative_config.num_speculative_tokens
if self._grammar_bitmask is None:
assert self.backend is not None
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
-if self.vllm_config.speculative_config is not None:
-    max_num_spec_tokens = self.vllm_config.\
-        speculative_config.num_speculative_tokens
-else:
-    max_num_spec_tokens = 0
# Allocate a bitmask for each token needing to be checked:
# one for each speculative position, and one more for the
@@ -103,6 +131,7 @@ class StructuredOutputManager:
self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens))
bitmask_tensor = self._grammar_bitmask
# Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position.
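To make the sizing concrete: the bitmask holds one row per token position that may need checking (the sampled token plus each speculative position) and packs the vocabulary into 32-bit words, one bit per token. A small worked sketch with illustrative numbers (the packed int32 layout is the usual xgrammar-style convention, not something introduced by this PR):

import math

import torch

# Illustrative numbers, not taken from any particular model config.
max_batch_size = 256        # scheduler_config.max_num_seqs
max_num_spec_tokens = 2     # speculative_config.num_speculative_tokens
vocab_size = 152_064

# One row per (request slot, position): sampled token + speculative positions.
num_rows = max_batch_size * (1 + max_num_spec_tokens)   # 768

# Each row packs the vocabulary into int32 words, one bit per token id.
words_per_row = math.ceil(vocab_size / 32)               # 4752

bitmask = torch.full((num_rows, words_per_row), -1, dtype=torch.int32)
print(bitmask.shape)  # torch.Size([768, 4752])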
@@ -110,16 +139,30 @@ class StructuredOutputManager:
cumulative_index = 0
ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1])
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
self._full_mask)
# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id].structured_output_request
-assert request is not None and request.grammar is not None
+if TYPE_CHECKING:
+    assert request is not None
+    assert request.grammar is not None
apply_bitmask = (
request.reasoning_ended if self.reasoner is not None else True
) # noqa: E501
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
-if not request.grammar.is_terminated():
-    request.grammar.fill_bitmask(self._grammar_bitmask,
+if apply_bitmask and not request.grammar.is_terminated():
+    request.grammar.fill_bitmask(bitmask_tensor,
cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
@@ -132,15 +175,41 @@ class StructuredOutputManager:
if state_advancements > 0:
request.grammar.rollback(state_advancements)
-bitmask_tensor = self._grammar_bitmask
-if cumulative_index < self._grammar_bitmask.shape[0]:
-    bitmask_tensor = self._grammar_bitmask[:cumulative_index]
+if cumulative_index < bitmask_tensor.shape[0]:
+    bitmask_tensor = bitmask_tensor[:cumulative_index]
# After finishing with the xgrammar operations, we convert to
# np.ndarray, because that is much more efficient for serialization
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()
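Aside: on the worker side the ndarray becomes a tensor again and is applied to the logits. A minimal sketch of how a packed bitmask of this shape is consumed, assuming the usual convention that a set bit means "token allowed" (so a row of all -1s, the _full_mask value written during the reset above, leaves the logits untouched); the actual workers use the backend's own apply kernel rather than this code.

import torch


def apply_packed_bitmask(logits: torch.Tensor,
                         bitmask: torch.Tensor) -> torch.Tensor:
    """Mask logits in place: bit i of word j covers token id j * 32 + i."""
    batch, vocab = logits.shape
    words = bitmask.to(torch.int64) & 0xFFFFFFFF            # avoid sign issues
    shifts = torch.arange(32, device=bitmask.device)
    allowed = ((words.unsqueeze(-1) >> shifts) & 1).bool()  # (batch, words, 32)
    allowed = allowed.reshape(batch, -1)[:, :vocab]
    logits.masked_fill_(~allowed, float("-inf"))
    return logits


# A row of all -1s (every bit set) is a no-op: all 100 tokens stay allowed.
logits = torch.randn(1, 100)
bitmask = torch.full((1, 4), -1, dtype=torch.int32)  # ceil(100 / 32) = 4
masked = apply_packed_bitmask(logits.clone(), bitmask)
assert torch.equal(masked, logits)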
def should_advance(self, request: Request) -> bool:
if not request.use_structured_output:
return False
# Determine whether we can advance the FSM for this request.
# Supports thinking mode, where the grammar is not advanced while
# the reasoning content is still being generated.
if TYPE_CHECKING:
assert request.structured_output_request is not None
assert request.structured_output_request.grammar is not None
# By default, we should always advance
# for cases that don't use thinking mode.
if self.reasoner is not None:
structured_req = request.structured_output_request
if structured_req.reasoning_ended:
return True
# Check if reasoning ends in *this* step
if self.reasoner.is_reasoning_end(request.all_token_ids):
# Reasoning just ended, so we shouldn't advance until the
# next pass
structured_req.reasoning_ended = True
return False
else:
return True
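Aside: the gating above reduces to a small decision table. A self-contained sketch of that logic follows; the argument names are simplified stand-ins for the request and reasoner state in the diff, not vLLM APIs.

def should_advance(uses_structured_output: bool,
                   has_reasoner: bool,
                   reasoning_ended: bool,
                   reasoning_ends_now: bool) -> tuple[bool, bool]:
    """Return (advance_grammar, updated_reasoning_ended_flag)."""
    if not uses_structured_output:
        return False, reasoning_ended
    if not has_reasoner:
        # No thinking mode configured: always advance the grammar FSM.
        return True, reasoning_ended
    if reasoning_ended:
        # Reasoning finished on an earlier step: constrain output normally.
        return True, reasoning_ended
    if reasoning_ends_now:
        # Reasoning ends on *this* step: remember it, but only start
        # advancing the grammar from the next pass onward.
        return False, True
    # Still inside the thinking section: leave the grammar untouched.
    return False, reasoning_ended


# The step where reasoning ends is itself not constrained...
assert should_advance(True, True, False, True) == (False, True)
# ...but every later step is.
assert should_advance(True, True, True, False) == (True, True)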
def clear_backend(self) -> None:
if self.backend is not None:
self.backend.destroy()