[Core] Streamline some structured output related code (#26737)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -30,7 +30,6 @@ from vllm.v1.kv_cache_interface import (
|
|||||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
|
||||||
|
|
||||||
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
|
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
|
||||||
|
|
||||||
@@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
|
|||||||
requests[0].request_id: [],
|
requests[0].request_id: [],
|
||||||
requests[1].request_id: [10],
|
requests[1].request_id: [10],
|
||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
|
|||||||
requests[0].request_id: [10, 42],
|
requests[0].request_id: [10, 42],
|
||||||
requests[1].request_id: [13],
|
requests[1].request_id: [13],
|
||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
|
|||||||
requests[0].request_id: [10, 11],
|
requests[0].request_id: [10, 11],
|
||||||
requests[1].request_id: [],
|
requests[1].request_id: [],
|
||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
|
|||||||
total_num_scheduled_tokens=3,
|
total_num_scheduled_tokens=3,
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
|
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
pooling_params=None,
|
pooling_params=None,
|
||||||
eos_token_id=EOS_TOKEN_ID,
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
structured_output_request=StructuredOutputRequest(sampling_params),
|
|
||||||
)
|
)
|
||||||
scheduler.add_request(request)
|
scheduler.add_request(request)
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
|
|||||||
num_common_prefix_blocks=[],
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
|
|||||||
total_num_scheduled_tokens=0,
|
total_num_scheduled_tokens=0,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids={req_id},
|
finished_req_ids={req_id},
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
total_num_scheduled_tokens=0,
|
total_num_scheduled_tokens=0,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
total_num_scheduled_tokens=1,
|
total_num_scheduled_tokens=1,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
|
|||||||
total_num_scheduled_tokens=1,
|
total_num_scheduled_tokens=1,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
|
|||||||
total_num_scheduled_tokens=1,
|
total_num_scheduled_tokens=1,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
|
|||||||
total_num_scheduled_tokens=0,
|
total_num_scheduled_tokens=0,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids={req_id},
|
finished_req_ids={req_id},
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
|||||||
total_num_scheduled_tokens=0,
|
total_num_scheduled_tokens=0,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
|||||||
total_num_scheduled_tokens=1,
|
total_num_scheduled_tokens=1,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
|
|||||||
total_num_scheduled_tokens=1,
|
total_num_scheduled_tokens=1,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
|
|||||||
total_num_scheduled_tokens=1,
|
total_num_scheduled_tokens=1,
|
||||||
scheduled_spec_decode_tokens={},
|
scheduled_spec_decode_tokens={},
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=[],
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_mm_hashes=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids=[],
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -165,9 +165,8 @@ class SchedulerOutput:
|
|||||||
# freed from the encoder cache.
|
# freed from the encoder cache.
|
||||||
free_encoder_mm_hashes: list[str]
|
free_encoder_mm_hashes: list[str]
|
||||||
|
|
||||||
# Dict of request ids to their index within the batch
|
# ids of structured outputs requests included in the bitmask, in order.
|
||||||
# for filling the next token bitmask
|
structured_output_request_ids: list[str]
|
||||||
structured_output_request_ids: dict[str, int]
|
|
||||||
# the bitmask for the whole batch
|
# the bitmask for the whole batch
|
||||||
grammar_bitmask: "npt.NDArray[np.int32] | None"
|
grammar_bitmask: "npt.NDArray[np.int32] | None"
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import itertools
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||||
@@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
|
|||||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
scheduled_spec_decode_tokens,
|
scheduled_spec_decode_tokens,
|
||||||
req_to_new_blocks,
|
req_to_new_blocks,
|
||||||
)
|
)
|
||||||
scheduled_requests = (
|
|
||||||
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
|
|
||||||
)
|
|
||||||
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
||||||
scheduled_requests, scheduled_spec_decode_tokens
|
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
|
||||||
)
|
)
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=new_reqs_data,
|
scheduled_new_reqs=new_reqs_data,
|
||||||
@@ -876,27 +877,23 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
def get_grammar_bitmask(
|
def get_grammar_bitmask(
|
||||||
self,
|
self,
|
||||||
requests: list[Request],
|
scheduled_request_ids: Iterable[str],
|
||||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||||
):
|
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
|
||||||
# NOTE: structured_output_request_ids maps
|
# Collect list of scheduled request ids that use structured output.
|
||||||
# a request's (request that uses structured output)
|
# The corresponding rows of the bitmask will be in this order.
|
||||||
# request_id to its index in the batch.
|
|
||||||
# This will help us determine to slice the grammar bitmask
|
|
||||||
# and only applies valid mask for requests that
|
|
||||||
# uses structured decoding.
|
|
||||||
structured_output_request_ids: dict[str, int] = {}
|
|
||||||
for i, req in enumerate(requests):
|
|
||||||
if req.use_structured_output:
|
|
||||||
# PERF: in case of chunked prefill,
|
# PERF: in case of chunked prefill,
|
||||||
# request might not include any new tokens.
|
# request might not include any new tokens.
|
||||||
# Therefore, we might introduce some additional
|
# Therefore, we might introduce some additional
|
||||||
# cycle to fill in the bitmask, which could be a big no-op.
|
# cycle to fill in the bitmask, which could be a big no-op.
|
||||||
structured_output_request_ids[req.request_id] = i
|
structured_output_request_ids = [
|
||||||
|
req_id
|
||||||
|
for req_id in scheduled_request_ids
|
||||||
|
if (req := self.requests.get(req_id)) and req.use_structured_output
|
||||||
|
]
|
||||||
if not structured_output_request_ids:
|
if not structured_output_request_ids:
|
||||||
bitmask = None
|
return structured_output_request_ids, None
|
||||||
else:
|
|
||||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||||
self.requests,
|
self.requests,
|
||||||
structured_output_request_ids,
|
structured_output_request_ids,
|
||||||
@@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
|
|||||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||||
|
|
||||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||||
# NOTE: structured_output_request
|
struct_output_request = request.structured_output_request
|
||||||
# should not be None if use_structured_output, we have
|
assert struct_output_request is not None
|
||||||
# checked above, so safe to ignore type warning
|
assert struct_output_request.grammar is not None
|
||||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
|
||||||
req_id, new_token_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ class Request:
|
|||||||
prompt_embeds: torch.Tensor | None = None,
|
prompt_embeds: torch.Tensor | None = None,
|
||||||
mm_features: list[MultiModalFeatureSpec] | None = None,
|
mm_features: list[MultiModalFeatureSpec] | None = None,
|
||||||
lora_request: Optional["LoRARequest"] = None,
|
lora_request: Optional["LoRARequest"] = None,
|
||||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
|
||||||
cache_salt: str | None = None,
|
cache_salt: str | None = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
trace_headers: Mapping[str, str] | None = None,
|
trace_headers: Mapping[str, str] | None = None,
|
||||||
@@ -54,11 +53,12 @@ class Request:
|
|||||||
# Because of LoRA, the eos token id can be different for each request.
|
# Because of LoRA, the eos token id can be different for each request.
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.lora_request = lora_request
|
self.lora_request = lora_request
|
||||||
self.structured_output_request = structured_output_request
|
self.structured_output_request = StructuredOutputRequest.from_sampling_params(
|
||||||
|
sampling_params
|
||||||
|
)
|
||||||
self.arrival_time = arrival_time if arrival_time is not None else time.time()
|
self.arrival_time = arrival_time if arrival_time is not None else time.time()
|
||||||
|
|
||||||
self.status = RequestStatus.WAITING
|
self.status = RequestStatus.WAITING
|
||||||
self.use_structured_output = False
|
|
||||||
self.events: list[EngineCoreEvent] = []
|
self.events: list[EngineCoreEvent] = []
|
||||||
self.stop_reason: int | str | None = None
|
self.stop_reason: int | str | None = None
|
||||||
|
|
||||||
@@ -72,9 +72,8 @@ class Request:
|
|||||||
# Generative models.
|
# Generative models.
|
||||||
assert sampling_params.max_tokens is not None
|
assert sampling_params.max_tokens is not None
|
||||||
self.max_tokens = sampling_params.max_tokens
|
self.max_tokens = sampling_params.max_tokens
|
||||||
if sampling_params.structured_outputs is not None:
|
if self.structured_output_request is not None:
|
||||||
self.status = RequestStatus.WAITING_FOR_FSM
|
self.status = RequestStatus.WAITING_FOR_FSM
|
||||||
self.use_structured_output = True
|
|
||||||
|
|
||||||
if sampling_params.extra_args is not None:
|
if sampling_params.extra_args is not None:
|
||||||
self.kv_transfer_params = sampling_params.extra_args.get(
|
self.kv_transfer_params = sampling_params.extra_args.get(
|
||||||
@@ -145,11 +144,6 @@ class Request:
|
|||||||
eos_token_id=request.eos_token_id,
|
eos_token_id=request.eos_token_id,
|
||||||
arrival_time=request.arrival_time,
|
arrival_time=request.arrival_time,
|
||||||
lora_request=request.lora_request,
|
lora_request=request.lora_request,
|
||||||
structured_output_request=StructuredOutputRequest(
|
|
||||||
sampling_params=request.sampling_params
|
|
||||||
)
|
|
||||||
if request.sampling_params
|
|
||||||
else None,
|
|
||||||
cache_salt=request.cache_salt,
|
cache_salt=request.cache_salt,
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
trace_headers=request.trace_headers,
|
trace_headers=request.trace_headers,
|
||||||
@@ -170,6 +164,10 @@ class Request:
|
|||||||
if self.get_hash_new_full_blocks is not None:
|
if self.get_hash_new_full_blocks is not None:
|
||||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_structured_output(self) -> bool:
|
||||||
|
return self.structured_output_request is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_output_corrupted(self) -> bool:
|
def is_output_corrupted(self) -> bool:
|
||||||
return self.num_nans_in_logits > 0
|
return self.num_nans_in_logits > 0
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ class StructuredOutputManager:
|
|||||||
def grammar_bitmask(
|
def grammar_bitmask(
|
||||||
self,
|
self,
|
||||||
requests: dict[str, Request],
|
requests: dict[str, Request],
|
||||||
structured_output_request_ids: dict[str, int],
|
structured_output_request_ids: list[str],
|
||||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||||
) -> "npt.NDArray[np.int32] | None":
|
) -> "npt.NDArray[np.int32] | None":
|
||||||
# Prepare the structured output bitmask for this batch.
|
# Prepare the structured output bitmask for this batch.
|
||||||
@@ -196,17 +196,16 @@ class StructuredOutputManager:
|
|||||||
# masks for each request, one for each possible bonus token position.
|
# masks for each request, one for each possible bonus token position.
|
||||||
# These are stored inline in the tensor and unpacked by the gpu runner.
|
# These are stored inline in the tensor and unpacked by the gpu runner.
|
||||||
cumulative_index = 0
|
cumulative_index = 0
|
||||||
ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
|
|
||||||
|
|
||||||
# Optimized parallel filling of bitmasks for
|
# Optimized parallel filling of bitmasks for
|
||||||
# non-spec, large-batch-size cases
|
# non-spec, large-batch-size cases
|
||||||
if (
|
if (
|
||||||
len(ordered_seq) > self.fill_bitmask_parallel_threshold
|
len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
|
||||||
and max_num_spec_tokens == 0
|
and max_num_spec_tokens == 0
|
||||||
):
|
):
|
||||||
promises = []
|
promises = []
|
||||||
batch = []
|
batch = []
|
||||||
for req_id, _ in ordered_seq:
|
for req_id in structured_output_request_ids:
|
||||||
request = requests[req_id]
|
request = requests[req_id]
|
||||||
structured_output_request = request.structured_output_request
|
structured_output_request = request.structured_output_request
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -230,7 +229,7 @@ class StructuredOutputManager:
|
|||||||
promise.result()
|
promise.result()
|
||||||
else:
|
else:
|
||||||
# Fallback to serial filling of bitmasks for small-batch-size cases
|
# Fallback to serial filling of bitmasks for small-batch-size cases
|
||||||
for req_id, _ in ordered_seq:
|
for req_id in structured_output_request_ids:
|
||||||
request = requests[req_id]
|
request = requests[req_id]
|
||||||
structured_output_request = request.structured_output_request
|
structured_output_request = request.structured_output_request
|
||||||
|
|
||||||
@@ -295,9 +294,10 @@ class StructuredOutputManager:
|
|||||||
assert request.structured_output_request.grammar is not None
|
assert request.structured_output_request.grammar is not None
|
||||||
# by default, we should always advance
|
# by default, we should always advance
|
||||||
# for cases that don't use thinking mode.
|
# for cases that don't use thinking mode.
|
||||||
if self.reasoner is not None:
|
if self.reasoner is None:
|
||||||
structured_req = request.structured_output_request
|
return True
|
||||||
|
|
||||||
|
structured_req = request.structured_output_request
|
||||||
if structured_req.reasoning_ended:
|
if structured_req.reasoning_ended:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -308,8 +308,6 @@ class StructuredOutputManager:
|
|||||||
structured_req.reasoning_ended = True
|
structured_req.reasoning_ended = True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def clear_backend(self) -> None:
|
def clear_backend(self) -> None:
|
||||||
if self.backend is not None:
|
if self.backend is not None:
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ def serialize_guidance_grammar(
|
|||||||
def validate_guidance_grammar(
|
def validate_guidance_grammar(
|
||||||
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
|
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
tp, grm = get_structured_output_key(sampling_params)
|
tp, grm = get_structured_output_key(sampling_params.structured_outputs)
|
||||||
guidance_grm = serialize_guidance_grammar(tp, grm)
|
guidance_grm = serialize_guidance_grammar(tp, grm)
|
||||||
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
|
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
|
||||||
if err:
|
if err:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from concurrent.futures import Future
|
|||||||
from concurrent.futures._base import TimeoutError
|
from concurrent.futures._base import TimeoutError
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||||
from vllm.v1.structured_output.backend_types import (
|
from vllm.v1.structured_output.backend_types import (
|
||||||
StructuredOutputGrammar,
|
StructuredOutputGrammar,
|
||||||
StructuredOutputKey,
|
StructuredOutputKey,
|
||||||
@@ -17,10 +17,19 @@ from vllm.v1.structured_output.backend_types import (
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class StructuredOutputRequest:
|
class StructuredOutputRequest:
|
||||||
sampling_params: SamplingParams
|
params: StructuredOutputsParams
|
||||||
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
|
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
|
||||||
reasoning_ended: bool | None = None
|
reasoning_ended: bool | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_sampling_params(
|
||||||
|
sampling_params: SamplingParams | None,
|
||||||
|
) -> "StructuredOutputRequest | None":
|
||||||
|
if sampling_params is None:
|
||||||
|
return None
|
||||||
|
params = sampling_params.structured_outputs
|
||||||
|
return StructuredOutputRequest(params=params) if params else None
|
||||||
|
|
||||||
def _check_grammar_completion(self) -> bool:
|
def _check_grammar_completion(self) -> bool:
|
||||||
# NOTE: We have to lazy import to gate circular imports
|
# NOTE: We have to lazy import to gate circular imports
|
||||||
from vllm.v1.request import RequestStatus
|
from vllm.v1.request import RequestStatus
|
||||||
@@ -53,31 +62,28 @@ class StructuredOutputRequest:
|
|||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def structured_output_key(self) -> StructuredOutputKey:
|
def structured_output_key(self) -> StructuredOutputKey:
|
||||||
return get_structured_output_key(self.sampling_params)
|
return get_structured_output_key(self.params)
|
||||||
|
|
||||||
|
|
||||||
def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey:
|
def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
|
||||||
params = sampling_params.structured_outputs
|
|
||||||
assert params is not None, "params can't be None."
|
|
||||||
if params.json is not None:
|
if params.json is not None:
|
||||||
if not isinstance(params.json, str):
|
if not isinstance(params.json, str):
|
||||||
json_str = json.dumps(params.json)
|
json_str = json.dumps(params.json)
|
||||||
else:
|
else:
|
||||||
json_str = params.json
|
json_str = params.json
|
||||||
return (StructuredOutputOptions.JSON, json_str)
|
return StructuredOutputOptions.JSON, json_str
|
||||||
elif params.json_object:
|
if params.json_object:
|
||||||
return (StructuredOutputOptions.JSON_OBJECT, "")
|
return StructuredOutputOptions.JSON_OBJECT, ""
|
||||||
elif params.regex is not None:
|
if params.regex is not None:
|
||||||
return (StructuredOutputOptions.REGEX, params.regex)
|
return StructuredOutputOptions.REGEX, params.regex
|
||||||
elif params.choice is not None:
|
if params.choice is not None:
|
||||||
if not isinstance(params.choice, str):
|
if not isinstance(params.choice, str):
|
||||||
json_str = json.dumps(params.choice)
|
json_str = json.dumps(params.choice)
|
||||||
else:
|
else:
|
||||||
json_str = params.choice
|
json_str = params.choice
|
||||||
return (StructuredOutputOptions.CHOICE, json_str)
|
return StructuredOutputOptions.CHOICE, json_str
|
||||||
elif params.grammar is not None:
|
if params.grammar is not None:
|
||||||
return (StructuredOutputOptions.GRAMMAR, params.grammar)
|
return StructuredOutputOptions.GRAMMAR, params.grammar
|
||||||
elif params.structural_tag is not None:
|
if params.structural_tag is not None:
|
||||||
return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
|
return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
|
||||||
else:
|
|
||||||
raise ValueError("No valid structured output parameter found")
|
raise ValueError("No valid structured output parameter found")
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ def apply_grammar_bitmask(
|
|||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
input_batch: InputBatch,
|
input_batch: InputBatch,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
device: torch.device,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Apply grammar bitmask to output logits of the model with xgrammar function.
|
Apply grammar bitmask to output logits of the model with xgrammar function.
|
||||||
@@ -56,7 +55,6 @@ def apply_grammar_bitmask(
|
|||||||
scheduler_output (SchedulerOutput): The result of engine scheduling.
|
scheduler_output (SchedulerOutput): The result of engine scheduling.
|
||||||
input_batch (InputBatch): The input of model runner.
|
input_batch (InputBatch): The input of model runner.
|
||||||
logits (torch.Tensor): The output logits of model forward.
|
logits (torch.Tensor): The output logits of model forward.
|
||||||
device (torch.device): The device that model runner running on.
|
|
||||||
"""
|
"""
|
||||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||||
if grammar_bitmask is None:
|
if grammar_bitmask is None:
|
||||||
@@ -91,10 +89,7 @@ def apply_grammar_bitmask(
|
|||||||
dtype=grammar_bitmask.dtype,
|
dtype=grammar_bitmask.dtype,
|
||||||
)
|
)
|
||||||
cumulative_index = 0
|
cumulative_index = 0
|
||||||
seq = sorted(
|
for req_id in scheduler_output.structured_output_request_ids:
|
||||||
scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]
|
|
||||||
)
|
|
||||||
for req_id, _ in seq:
|
|
||||||
num_spec_tokens = len(
|
num_spec_tokens = len(
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||||
)
|
)
|
||||||
@@ -117,7 +112,7 @@ def apply_grammar_bitmask(
|
|||||||
|
|
||||||
xgr.apply_token_bitmask_inplace(
|
xgr.apply_token_bitmask_inplace(
|
||||||
logits,
|
logits,
|
||||||
grammar_bitmask.to(device, non_blocking=True),
|
grammar_bitmask.to(logits.device, non_blocking=True),
|
||||||
indices=out_indices if not skip_out_indices else None,
|
indices=out_indices if not skip_out_indices else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2568,10 +2568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
logits = model_output_broadcast_data["logits"]
|
logits = model_output_broadcast_data["logits"]
|
||||||
|
|
||||||
# Apply structured output bitmasks if present
|
# Apply structured output bitmasks if present
|
||||||
if scheduler_output.grammar_bitmask is not None:
|
if scheduler_output.structured_output_request_ids:
|
||||||
apply_grammar_bitmask(
|
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
|
||||||
scheduler_output, self.input_batch, logits, self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
with record_function_or_nullcontext("Sample"):
|
with record_function_or_nullcontext("Sample"):
|
||||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||||
|
|||||||
@@ -1963,12 +1963,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.grammar_bitmask_cpu.zero_()
|
self.grammar_bitmask_cpu.zero_()
|
||||||
self.require_structured_out_cpu.zero_()
|
self.require_structured_out_cpu.zero_()
|
||||||
|
|
||||||
sorted_struct_requests = sorted(
|
|
||||||
scheduler_output.structured_output_request_ids.items(),
|
|
||||||
key=lambda item: item[1],
|
|
||||||
)
|
|
||||||
cumulative_mask_idx = 0
|
cumulative_mask_idx = 0
|
||||||
for req_id, _ in sorted_struct_requests:
|
for req_id in scheduler_output.structured_output_request_ids:
|
||||||
if req_id not in self.input_batch.req_id_to_index:
|
if req_id not in self.input_batch.req_id_to_index:
|
||||||
continue
|
continue
|
||||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
|||||||
Reference in New Issue
Block a user