[Core] Streamline some structured output related code (#26737)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-10-14 16:27:44 -07:00
committed by GitHub
parent a86b4c58e8
commit 4aed506b65
13 changed files with 121 additions and 138 deletions

View File

@@ -5,7 +5,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
from typing import TYPE_CHECKING, Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__)
@@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
scheduled_requests = (
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
scheduled_requests, scheduled_spec_decode_tokens
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
@@ -876,32 +877,28 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask(
self,
requests: list[Request],
scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
):
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to its index in the batch.
# This will help us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
for i, req in enumerate(requests):
if req.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[req.request_id] = i
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids:
bitmask = None
else:
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, None
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
def update_from_output(
@@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and self.structured_output_manager.should_advance(request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# checked above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids
)
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]