[Core] Async scheduling + structured outputs compatibility (#26866)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import itertools
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
@@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
|
||||
)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.output import (
|
||||
CachedRequestData,
|
||||
GrammarOutput,
|
||||
NewRequestData,
|
||||
SchedulerOutput,
|
||||
)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
@@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
||||
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
|
||||
)
|
||||
|
||||
# Record the request ids that were scheduled in this step.
|
||||
self.prev_step_scheduled_req_ids.clear()
|
||||
@@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
structured_output_request_ids=structured_output_request_ids,
|
||||
grammar_bitmask=grammar_bitmask,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
@@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
def get_grammar_bitmask(
|
||||
self,
|
||||
scheduled_request_ids: Iterable[str],
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> GrammarOutput | None:
|
||||
# Collect list of scheduled request ids that use structured output.
|
||||
# The corresponding rows of the bitmask will be in this order.
|
||||
# PERF: in case of chunked prefill,
|
||||
@@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
|
||||
# cycle to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids = [
|
||||
req_id
|
||||
for req_id in scheduled_request_ids
|
||||
for req_id in scheduler_output.num_scheduled_tokens
|
||||
if (req := self.requests.get(req_id)) and req.use_structured_output
|
||||
]
|
||||
if not structured_output_request_ids:
|
||||
return structured_output_request_ids, None
|
||||
return None
|
||||
|
||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
scheduled_spec_decode_tokens,
|
||||
scheduler_output.scheduled_spec_decode_tokens,
|
||||
)
|
||||
return structured_output_request_ids, bitmask
|
||||
return GrammarOutput(structured_output_request_ids, bitmask)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user