[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Benjamin Chislett
2025-04-29 17:02:10 -07:00
committed by GitHub
parent 7489ec0bab
commit 34120f5acd
9 changed files with 207 additions and 57 deletions

@@ -27,6 +27,7 @@ class StructuredOutputManager:
     def __init__(self, vllm_config: VllmConfig):
         self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
+        self._grammar_bitmask: Optional[torch.Tensor] = None
 
         # The default max_workers if not specified is the number of CPUs * 5,
@@ -80,7 +81,7 @@ class StructuredOutputManager:
         self,
         requests: dict[str, Request],
         structured_output_request_ids: dict[str, int],
-        batch_len: int,
+        scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
         if not structured_output_request_ids:
@@ -88,20 +89,52 @@ class StructuredOutputManager:
         if self._grammar_bitmask is None:
             assert self.backend is not None
-            self._grammar_bitmask = self.backend.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs)
+            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+            if self.vllm_config.speculative_config is not None:
+                max_num_spec_tokens = self.vllm_config.\
+                    speculative_config.num_speculative_tokens
+            else:
+                max_num_spec_tokens = 0
 
-        # Fill the bitmask using the index of each request equal to its
-        # position in the batch. Resize the bitmask down to the size of
-        # the batch.
-        bitmask_tensor = self._grammar_bitmask
-        for req_id, batch_index in structured_output_request_ids.items():
+            # Allocate a bitmask for each token needing to be checked:
+            # one for each speculative position, and one more for the
+            # bonus token / non-speculative token.
+            self._grammar_bitmask = \
+                self.backend.allocate_token_bitmask(
+                    max_batch_size * (1 + max_num_spec_tokens))
+
+        # Generate a batched bitmask for all structured output requests.
+        # When speculative decoding is enabled, we need to include multiple
+        # masks for each request, one for each possible bonus token position.
+        # These are stored inline in the tensor and unpacked by the gpu runner.
+        cumulative_index = 0
+        ordered_seq = sorted(structured_output_request_ids.items(),
+                             key=lambda x: x[1])
+        # NOTE: This outer loop can likely be parallelized to improve
+        # performance of bitmask generation for large batches.
+        for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.is_terminated():
-                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
-        if batch_len < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:batch_len]
+            state_advancements = 0
+            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
+            for i, token in enumerate(req_tokens):
+                if not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                                                 cumulative_index)
+                if token is not None:
+                    # In order to generate the correct bitmask for each
+                    # position in the speculative sequence, we advance
+                    # the FSM state for each speculative token and rollback
+                    # to restore the previous state when we are finished.
+                    assert request.grammar.accept_tokens(req_id, [token])
+                    state_advancements += 1
+                cumulative_index += 1
+            if state_advancements > 0:
+                request.grammar.rollback(state_advancements)
+
+        bitmask_tensor = self._grammar_bitmask
+        if cumulative_index < self._grammar_bitmask.shape[0]:
+            bitmask_tensor = self._grammar_bitmask[:cumulative_index]
 
         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
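
Not part of the diff: a minimal sketch of how the packed bitmask rows produced above could be sliced back out per request downstream. It assumes only the row layout established in this change (rows packed densely in batch order, one row per scheduled speculative token plus one bonus-token row); the helper name and the NumPy-based interface are illustrative and are not vLLM's actual GPU-runner code.

import numpy as np


def unpack_bitmask_rows(
    bitmask: np.ndarray,            # packed rows, shape (num_rows, bitmask_words)
    ordered_req_ids: list[str],     # structured-output request ids, in batch order
    scheduled_spec_decode_tokens: dict[str, list[int]],
) -> dict[str, np.ndarray]:
    """Slice the packed bitmask back into one block of rows per request."""
    rows: dict[str, np.ndarray] = {}
    cursor = 0
    for req_id in ordered_req_ids:
        # Mirrors `req_tokens = spec_tokens + [None]` above: one row per
        # speculative position plus one row for the bonus token.
        num_rows = len(scheduled_spec_decode_tokens.get(req_id, [])) + 1
        rows[req_id] = bitmask[cursor:cursor + num_rows]
        cursor += num_rows
    return rows

For example, with two structured-output requests where the first has two scheduled speculative tokens and the second has none, the first request owns rows 0-2 and the second owns row 3, matching how `cumulative_index` advances in the loop above.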