[Core] Streamline some structured output related code (#26737)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-10-14 16:27:44 -07:00
committed by GitHub
parent a86b4c58e8
commit 4aed506b65
13 changed files with 121 additions and 138 deletions

View File

@@ -30,7 +30,6 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
@@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [], requests[0].request_id: [],
requests[1].request_id: [10], requests[1].request_id: [10],
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [10, 42], requests[0].request_id: [10, 42],
requests[1].request_id: [13], requests[1].request_id: [13],
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [10, 11], requests[0].request_id: [10, 11],
requests[1].request_id: [], requests[1].request_id: [],
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
total_num_scheduled_tokens=3, total_num_scheduled_tokens=3,
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]}, scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
structured_output_request=StructuredOutputRequest(sampling_params),
) )
scheduler.add_request(request) scheduler.add_request(request)
output = scheduler.schedule() output = scheduler.schedule()

View File

@@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(), kv_connector_metadata=SharedStorageConnectorMetadata(),
) )

View File

@@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )

View File

@@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )
@@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids=[],
grammar_bitmask=None, grammar_bitmask=None,
) )

View File

@@ -165,9 +165,8 @@ class SchedulerOutput:
# freed from the encoder cache. # freed from the encoder cache.
free_encoder_mm_hashes: list[str] free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch # ids of structured outputs requests included in the bitmask, in order.
# for filling the next token bitmask structured_output_request_ids: list[str]
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch # the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None" grammar_bitmask: "npt.NDArray[np.int32] | None"

View File

@@ -5,7 +5,7 @@ import itertools
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any from typing import TYPE_CHECKING, Any
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_blocks, req_to_new_blocks,
) )
scheduled_requests = (
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
scheduled_requests, scheduled_spec_decode_tokens num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
@@ -876,27 +877,23 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask( def get_grammar_bitmask(
self, self,
requests: list[Request], scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]], scheduled_spec_decode_tokens: dict[str, list[int]],
): ) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
# NOTE: structured_output_request_ids maps # Collect list of scheduled request ids that use structured output.
# a request's (request that uses structured output) # The corresponding rows of the bitmask will be in this order.
# request_id to its index in the batch.
# This will help us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
for i, req in enumerate(requests):
if req.use_structured_output:
# PERF: in case of chunked prefill, # PERF: in case of chunked prefill,
# request might not include any new tokens. # request might not include any new tokens.
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[req.request_id] = i structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids: if not structured_output_request_ids:
bitmask = None return structured_output_request_ids, None
else:
bitmask = self.structured_output_manager.grammar_bitmask( bitmask = self.structured_output_manager.grammar_bitmask(
self.requests, self.requests,
structured_output_request_ids, structured_output_request_ids,
@@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and self.structured_output_manager.should_advance(request): if new_token_ids and self.structured_output_manager.should_advance(request):
# NOTE: structured_output_request struct_output_request = request.structured_output_request
# should not be None if use_structured_output, we have assert struct_output_request is not None
# checked above, so safe to ignore type warning assert struct_output_request.grammar is not None
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
req_id, new_token_ids
)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]

View File

@@ -40,7 +40,6 @@ class Request:
prompt_embeds: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None,
mm_features: list[MultiModalFeatureSpec] | None = None, mm_features: list[MultiModalFeatureSpec] | None = None,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: str | None = None, cache_salt: str | None = None,
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
@@ -54,11 +53,12 @@ class Request:
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.structured_output_request = structured_output_request self.structured_output_request = StructuredOutputRequest.from_sampling_params(
sampling_params
)
self.arrival_time = arrival_time if arrival_time is not None else time.time() self.arrival_time = arrival_time if arrival_time is not None else time.time()
self.status = RequestStatus.WAITING self.status = RequestStatus.WAITING
self.use_structured_output = False
self.events: list[EngineCoreEvent] = [] self.events: list[EngineCoreEvent] = []
self.stop_reason: int | str | None = None self.stop_reason: int | str | None = None
@@ -72,9 +72,8 @@ class Request:
# Generative models. # Generative models.
assert sampling_params.max_tokens is not None assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens self.max_tokens = sampling_params.max_tokens
if sampling_params.structured_outputs is not None: if self.structured_output_request is not None:
self.status = RequestStatus.WAITING_FOR_FSM self.status = RequestStatus.WAITING_FOR_FSM
self.use_structured_output = True
if sampling_params.extra_args is not None: if sampling_params.extra_args is not None:
self.kv_transfer_params = sampling_params.extra_args.get( self.kv_transfer_params = sampling_params.extra_args.get(
@@ -145,11 +144,6 @@ class Request:
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time, arrival_time=request.arrival_time,
lora_request=request.lora_request, lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params
)
if request.sampling_params
else None,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
priority=request.priority, priority=request.priority,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
@@ -170,6 +164,10 @@ class Request:
if self.get_hash_new_full_blocks is not None: if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks()) self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def use_structured_output(self) -> bool:
return self.structured_output_request is not None
@property @property
def is_output_corrupted(self) -> bool: def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0 return self.num_nans_in_logits > 0

View File

@@ -167,7 +167,7 @@ class StructuredOutputManager:
def grammar_bitmask( def grammar_bitmask(
self, self,
requests: dict[str, Request], requests: dict[str, Request],
structured_output_request_ids: dict[str, int], structured_output_request_ids: list[str],
scheduled_spec_decode_tokens: dict[str, list[int]], scheduled_spec_decode_tokens: dict[str, list[int]],
) -> "npt.NDArray[np.int32] | None": ) -> "npt.NDArray[np.int32] | None":
# Prepare the structured output bitmask for this batch. # Prepare the structured output bitmask for this batch.
@@ -196,17 +196,16 @@ class StructuredOutputManager:
# masks for each request, one for each possible bonus token position. # masks for each request, one for each possible bonus token position.
# These are stored inline in the tensor and unpacked by the gpu runner. # These are stored inline in the tensor and unpacked by the gpu runner.
cumulative_index = 0 cumulative_index = 0
ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
# Optimized parallel filling of bitmasks for # Optimized parallel filling of bitmasks for
# non-spec, large-batch-size cases # non-spec, large-batch-size cases
if ( if (
len(ordered_seq) > self.fill_bitmask_parallel_threshold len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
and max_num_spec_tokens == 0 and max_num_spec_tokens == 0
): ):
promises = [] promises = []
batch = [] batch = []
for req_id, _ in ordered_seq: for req_id in structured_output_request_ids:
request = requests[req_id] request = requests[req_id]
structured_output_request = request.structured_output_request structured_output_request = request.structured_output_request
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -230,7 +229,7 @@ class StructuredOutputManager:
promise.result() promise.result()
else: else:
# Fallback to serial filling of bitmasks for small-batch-size cases # Fallback to serial filling of bitmasks for small-batch-size cases
for req_id, _ in ordered_seq: for req_id in structured_output_request_ids:
request = requests[req_id] request = requests[req_id]
structured_output_request = request.structured_output_request structured_output_request = request.structured_output_request
@@ -295,9 +294,10 @@ class StructuredOutputManager:
assert request.structured_output_request.grammar is not None assert request.structured_output_request.grammar is not None
# by default, we should always advance # by default, we should always advance
# for cases that don't use thinking mode. # for cases that don't use thinking mode.
if self.reasoner is not None: if self.reasoner is None:
structured_req = request.structured_output_request return True
structured_req = request.structured_output_request
if structured_req.reasoning_ended: if structured_req.reasoning_ended:
return True return True
@@ -308,8 +308,6 @@ class StructuredOutputManager:
structured_req.reasoning_ended = True structured_req.reasoning_ended = True
return False return False
else:
return True
def clear_backend(self) -> None: def clear_backend(self) -> None:
if self.backend is not None: if self.backend is not None:

View File

@@ -252,7 +252,7 @@ def serialize_guidance_grammar(
def validate_guidance_grammar( def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None: ) -> None:
tp, grm = get_structured_output_key(sampling_params) tp, grm = get_structured_output_key(sampling_params.structured_outputs)
guidance_grm = serialize_guidance_grammar(tp, grm) guidance_grm = serialize_guidance_grammar(tp, grm)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer) err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
if err: if err:

View File

@@ -7,7 +7,7 @@ from concurrent.futures import Future
from concurrent.futures._base import TimeoutError from concurrent.futures._base import TimeoutError
from typing import cast from typing import cast
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.structured_output.backend_types import ( from vllm.v1.structured_output.backend_types import (
StructuredOutputGrammar, StructuredOutputGrammar,
StructuredOutputKey, StructuredOutputKey,
@@ -17,10 +17,19 @@ from vllm.v1.structured_output.backend_types import (
@dataclasses.dataclass @dataclasses.dataclass
class StructuredOutputRequest: class StructuredOutputRequest:
sampling_params: SamplingParams params: StructuredOutputsParams
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None _grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
reasoning_ended: bool | None = None reasoning_ended: bool | None = None
@staticmethod
def from_sampling_params(
sampling_params: SamplingParams | None,
) -> "StructuredOutputRequest | None":
if sampling_params is None:
return None
params = sampling_params.structured_outputs
return StructuredOutputRequest(params=params) if params else None
def _check_grammar_completion(self) -> bool: def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports # NOTE: We have to lazy import to gate circular imports
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
@@ -53,31 +62,28 @@ class StructuredOutputRequest:
@functools.cached_property @functools.cached_property
def structured_output_key(self) -> StructuredOutputKey: def structured_output_key(self) -> StructuredOutputKey:
return get_structured_output_key(self.sampling_params) return get_structured_output_key(self.params)
def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey: def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
params = sampling_params.structured_outputs
assert params is not None, "params can't be None."
if params.json is not None: if params.json is not None:
if not isinstance(params.json, str): if not isinstance(params.json, str):
json_str = json.dumps(params.json) json_str = json.dumps(params.json)
else: else:
json_str = params.json json_str = params.json
return (StructuredOutputOptions.JSON, json_str) return StructuredOutputOptions.JSON, json_str
elif params.json_object: if params.json_object:
return (StructuredOutputOptions.JSON_OBJECT, "") return StructuredOutputOptions.JSON_OBJECT, ""
elif params.regex is not None: if params.regex is not None:
return (StructuredOutputOptions.REGEX, params.regex) return StructuredOutputOptions.REGEX, params.regex
elif params.choice is not None: if params.choice is not None:
if not isinstance(params.choice, str): if not isinstance(params.choice, str):
json_str = json.dumps(params.choice) json_str = json.dumps(params.choice)
else: else:
json_str = params.choice json_str = params.choice
return (StructuredOutputOptions.CHOICE, json_str) return StructuredOutputOptions.CHOICE, json_str
elif params.grammar is not None: if params.grammar is not None:
return (StructuredOutputOptions.GRAMMAR, params.grammar) return StructuredOutputOptions.GRAMMAR, params.grammar
elif params.structural_tag is not None: if params.structural_tag is not None:
return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag) return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
else:
raise ValueError("No valid structured output parameter found") raise ValueError("No valid structured output parameter found")

View File

@@ -47,7 +47,6 @@ def apply_grammar_bitmask(
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
input_batch: InputBatch, input_batch: InputBatch,
logits: torch.Tensor, logits: torch.Tensor,
device: torch.device,
) -> None: ) -> None:
""" """
Apply grammar bitmask to output logits of the model with xgrammar function. Apply grammar bitmask to output logits of the model with xgrammar function.
@@ -56,7 +55,6 @@ def apply_grammar_bitmask(
scheduler_output (SchedulerOutput): The result of engine scheduling. scheduler_output (SchedulerOutput): The result of engine scheduling.
input_batch (InputBatch): The input of model runner. input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward. logits (torch.Tensor): The output logits of model forward.
device (torch.device): The device that model runner running on.
""" """
grammar_bitmask = scheduler_output.grammar_bitmask grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None: if grammar_bitmask is None:
@@ -91,10 +89,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype, dtype=grammar_bitmask.dtype,
) )
cumulative_index = 0 cumulative_index = 0
seq = sorted( for req_id in scheduler_output.structured_output_request_ids:
scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]
)
for req_id, _ in seq:
num_spec_tokens = len( num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
) )
@@ -117,7 +112,7 @@ def apply_grammar_bitmask(
xgr.apply_token_bitmask_inplace( xgr.apply_token_bitmask_inplace(
logits, logits,
grammar_bitmask.to(device, non_blocking=True), grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None, indices=out_indices if not skip_out_indices else None,
) )

View File

@@ -2568,10 +2568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = model_output_broadcast_data["logits"] logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present # Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None: if scheduler_output.structured_output_request_ids:
apply_grammar_bitmask( apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
scheduler_output, self.input_batch, logits, self.device
)
with record_function_or_nullcontext("Sample"): with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata) sampler_output = self._sample(logits, spec_decode_metadata)

View File

@@ -1963,12 +1963,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.grammar_bitmask_cpu.zero_() self.grammar_bitmask_cpu.zero_()
self.require_structured_out_cpu.zero_() self.require_structured_out_cpu.zero_()
sorted_struct_requests = sorted(
scheduler_output.structured_output_request_ids.items(),
key=lambda item: item[1],
)
cumulative_mask_idx = 0 cumulative_mask_idx = 0
for req_id, _ in sorted_struct_requests: for req_id in scheduler_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index: if req_id not in self.input_batch.req_id_to_index:
continue continue
batch_index = self.input_batch.req_id_to_index[req_id] batch_index = self.input_batch.req_id_to_index[req_id]