[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
committed by
GitHub
parent
7489ec0bab
commit
34120f5acd
@@ -27,6 +27,7 @@ class StructuredOutputManager:
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.backend: Optional[StructuredOutputBackend] = None
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
self._grammar_bitmask: Optional[torch.Tensor] = None
|
||||
|
||||
# The default max_workers if not specified is the number of CPUs * 5,
|
||||
@@ -80,7 +81,7 @@ class StructuredOutputManager:
|
||||
self,
|
||||
requests: dict[str, Request],
|
||||
structured_output_request_ids: dict[str, int],
|
||||
batch_len: int,
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
) -> Optional[npt.NDArray[np.int32]]:
|
||||
# Prepare the structured output bitmask for this batch.
|
||||
if not structured_output_request_ids:
|
||||
@@ -88,20 +89,52 @@ class StructuredOutputManager:
|
||||
|
||||
if self._grammar_bitmask is None:
|
||||
assert self.backend is not None
|
||||
self._grammar_bitmask = self.backend.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs)
|
||||
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
max_num_spec_tokens = self.vllm_config.\
|
||||
speculative_config.num_speculative_tokens
|
||||
else:
|
||||
max_num_spec_tokens = 0
|
||||
|
||||
# Fill the bitmask using the index of each request equal to its
|
||||
# position in the batch. Resize the bitmask down to the size of
|
||||
# the batch.
|
||||
bitmask_tensor = self._grammar_bitmask
|
||||
for req_id, batch_index in structured_output_request_ids.items():
|
||||
# Allocate a bitmask for each token needing to be checked:
|
||||
# one for each speculative position, and one more for the
|
||||
# bonus token / non-speculative token.
|
||||
self._grammar_bitmask = \
|
||||
self.backend.allocate_token_bitmask(
|
||||
max_batch_size * (1 + max_num_spec_tokens))
|
||||
|
||||
# Generate a batched bitmask for all structured output requests.
|
||||
# When speculative decoding is enabled, we need to include multiple
|
||||
# masks for each request, one for each possible bonus token position.
|
||||
# These are stored inline in the tensor and unpacked by the gpu runner.
|
||||
cumulative_index = 0
|
||||
ordered_seq = sorted(structured_output_request_ids.items(),
|
||||
key=lambda x: x[1])
|
||||
# NOTE: This outer loop can likely be parallelized to improve
|
||||
# performance of bitmask generation for large batches.
|
||||
for req_id, _ in ordered_seq:
|
||||
request = requests[req_id].structured_output_request
|
||||
assert request is not None and request.grammar is not None
|
||||
if not request.grammar.is_terminated():
|
||||
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
|
||||
if batch_len < self._grammar_bitmask.shape[0]:
|
||||
bitmask_tensor = self._grammar_bitmask[:batch_len]
|
||||
state_advancements = 0
|
||||
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
|
||||
for i, token in enumerate(req_tokens):
|
||||
if not request.grammar.is_terminated():
|
||||
request.grammar.fill_bitmask(self._grammar_bitmask,
|
||||
cumulative_index)
|
||||
if token is not None:
|
||||
# In order to generate the correct bitmask for each
|
||||
# position in the speculative sequence, we advance
|
||||
# the FSM state for each speculative token and rollback
|
||||
# to restore the previous state when we are finished.
|
||||
assert request.grammar.accept_tokens(req_id, [token])
|
||||
state_advancements += 1
|
||||
cumulative_index += 1
|
||||
if state_advancements > 0:
|
||||
request.grammar.rollback(state_advancements)
|
||||
|
||||
bitmask_tensor = self._grammar_bitmask
|
||||
if cumulative_index < self._grammar_bitmask.shape[0]:
|
||||
bitmask_tensor = self._grammar_bitmask[:cumulative_index]
|
||||
|
||||
# After finishing with the xgrammar operations, we convert to
|
||||
# np.ndarray, because that is much more efficient for serialization
|
||||
|
||||
Reference in New Issue
Block a user