[Core] Streamline some structured output related code (#26737)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import itertools
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
@@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
scheduled_requests = (
|
||||
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
|
||||
)
|
||||
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
||||
scheduled_requests, scheduled_spec_decode_tokens
|
||||
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
|
||||
)
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
@@ -876,32 +877,28 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
def get_grammar_bitmask(
|
||||
self,
|
||||
requests: list[Request],
|
||||
scheduled_request_ids: Iterable[str],
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
):
|
||||
# NOTE: structured_output_request_ids maps
|
||||
# a request's (request that uses structured output)
|
||||
# request_id to its index in the batch.
|
||||
# This will help us determine to slice the grammar bitmask
|
||||
# and only applies valid mask for requests that
|
||||
# uses structured decoding.
|
||||
structured_output_request_ids: dict[str, int] = {}
|
||||
for i, req in enumerate(requests):
|
||||
if req.use_structured_output:
|
||||
# PERF: in case of chunked prefill,
|
||||
# request might not include any new tokens.
|
||||
# Therefore, we might introduce some additional
|
||||
# cycle to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids[req.request_id] = i
|
||||
|
||||
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
|
||||
# Collect list of scheduled request ids that use structured output.
|
||||
# The corresponding rows of the bitmask will be in this order.
|
||||
# PERF: in case of chunked prefill,
|
||||
# request might not include any new tokens.
|
||||
# Therefore, we might introduce some additional
|
||||
# cycle to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids = [
|
||||
req_id
|
||||
for req_id in scheduled_request_ids
|
||||
if (req := self.requests.get(req_id)) and req.use_structured_output
|
||||
]
|
||||
if not structured_output_request_ids:
|
||||
bitmask = None
|
||||
else:
|
||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
scheduled_spec_decode_tokens,
|
||||
)
|
||||
return structured_output_request_ids, None
|
||||
|
||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
scheduled_spec_decode_tokens,
|
||||
)
|
||||
return structured_output_request_ids, bitmask
|
||||
|
||||
def update_from_output(
|
||||
@@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# checked above, so safe to ignore type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
req_id, new_token_ids
|
||||
)
|
||||
struct_output_request = request.structured_output_request
|
||||
assert struct_output_request is not None
|
||||
assert struct_output_request.grammar is not None
|
||||
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
|
||||
|
||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||
|
||||
Reference in New Issue
Block a user