[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>

Committed by GitHub
parent 7489ec0bab
commit 34120f5acd
```diff
@@ -957,46 +957,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
     ):
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
         grammar_bitmask = scheduler_output.grammar_bitmask
         if grammar_bitmask is None:
             return
 
-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the gpu runner is
-        # ordering the requests in the batch. We need to sort the bitmask to
-        # match the order of the requests used here.
+        # We receive the structured output bitmask from the scheduler,
+        # compacted to contain bitmasks only for structured output requests.
+        # The order of the requests in the bitmask is not guaranteed to be the
+        # same as the order of the requests in the gpu runner's batch. We need
+        # to sort the bitmask to match the order of the requests used here.
+
         # Get the batch indices of the structured output requests.
+        # Keep track of the number of speculative tokens scheduled for every
+        # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
-        indices_match = True
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
-                # not a structured output request
-                continue
-            batch_index = self.input_batch.req_id_to_index[req_id]
-            if batch_index != mask_index:
-                indices_match = False
-            struct_out_req_batch_indices[req_id] = batch_index
+        cumulative_offset = 0
+        seq = sorted(self.input_batch.req_id_to_index.items(),
+                     key=lambda x: x[1])
+        for req_id, batch_index in seq:
+            logit_index = batch_index + cumulative_offset
+            cumulative_offset += len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            if req_id in scheduler_output.structured_output_request_ids:
+                struct_out_req_batch_indices[req_id] = logit_index
 
-        if not indices_match:
-            # Sort the bitmask to match the order of the requests
-            sorted_bitmask = np.zeros_like(grammar_bitmask)
-            for req_id, batch_index in struct_out_req_batch_indices.items():
-                orig_index = scheduler_output.structured_output_request_ids[
-                    req_id]
-                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
-            grammar_bitmask = sorted_bitmask
+        out_indices = []
+
+        # Reorder the bitmask to match the order of the requests in the batch.
+        sorted_bitmask = np.zeros_like(grammar_bitmask,
+                                       shape=(logits.shape[0],
+                                              grammar_bitmask.shape[1]))
+        cumulative_index = 0
+        seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                     key=lambda x: x[1])
+        for req_id, _ in seq:
+            logit_index = struct_out_req_batch_indices[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            for i in range(1 + num_spec_tokens):
+                sorted_bitmask[logit_index + i] = \
+                    grammar_bitmask[cumulative_index + i]
+                out_indices.append(logit_index + i)
+            cumulative_index += 1 + num_spec_tokens
+        grammar_bitmask = sorted_bitmask
 
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)
 
-        # TODO: compatibility with spec decode
         xgr.apply_token_bitmask_inplace(
             logits,
             grammar_bitmask.to(self.device, non_blocking=True),
-            indices=list(struct_out_req_batch_indices.values()),
+            indices=out_indices,
         )
 
     @torch.inference_mode()
```
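With speculative decoding enabled, each running request produces logits for its bonus token plus every scheduled speculative token, so a request's logit rows no longer line up with its batch index. The hunk above therefore (1) offsets each batch index by the speculative tokens of the requests ahead of it, and (2) scatters the scheduler's compacted bitmask across all `1 + num_spec_tokens` rows of each constrained request. Below is a minimal, self-contained sketch of that logic outside vLLM; the function name `expand_and_reorder_bitmask` and its arguments are hypothetical stand-ins for the state the real `GPUModelRunner` reads from `SchedulerOutput` and its input batch.

```python
import numpy as np

def expand_and_reorder_bitmask(
    grammar_bitmask: np.ndarray,              # compacted: one row per constrained token
    structured_request_ids: dict[str, int],   # req_id -> order among constrained requests
    req_id_to_batch_index: dict[str, int],    # req_id -> position in the GPU batch
    spec_tokens: dict[str, int],              # req_id -> number of speculative tokens
    num_logit_rows: int,                      # logits.shape[0]
) -> tuple[np.ndarray, list[int]]:
    # First pass: map each request to the logit row where its outputs start.
    # Every request ahead of it in the batch contributes 1 + num_spec_tokens
    # rows, so batch position alone is not enough once spec decode is on.
    first_logit_row: dict[str, int] = {}
    cumulative_offset = 0
    for req_id, batch_index in sorted(req_id_to_batch_index.items(),
                                      key=lambda x: x[1]):
        first_logit_row[req_id] = batch_index + cumulative_offset
        cumulative_offset += spec_tokens.get(req_id, 0)

    # Second pass: scatter the compacted bitmask (one row per constrained
    # token, in scheduler order) into logit order, recording which logit
    # rows actually received a mask.
    expanded = np.zeros((num_logit_rows, grammar_bitmask.shape[1]),
                        dtype=grammar_bitmask.dtype)
    out_indices: list[int] = []
    cumulative_index = 0
    for req_id, _ in sorted(structured_request_ids.items(),
                            key=lambda x: x[1]):
        start = first_logit_row[req_id]
        num_rows = 1 + spec_tokens.get(req_id, 0)
        for i in range(num_rows):
            expanded[start + i] = grammar_bitmask[cumulative_index + i]
            out_indices.append(start + i)
        cumulative_index += num_rows
    return expanded, out_indices
```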
|
||||
|
||||
Reference in New Issue
Block a user
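A hypothetical example run of the sketch, with one unconstrained request carrying one speculative token followed by one structured-output request carrying two:

```python
mask_words = 4  # bitmask width in int32 words (roughly vocab_size / 32)
# Compacted bitmask: 3 rows for req_b (1 bonus token + 2 spec tokens).
compact = np.arange(3 * mask_words, dtype=np.int32).reshape(3, mask_words)

expanded, out_indices = expand_and_reorder_bitmask(
    grammar_bitmask=compact,
    structured_request_ids={"req_b": 0},
    req_id_to_batch_index={"req_a": 0, "req_b": 1},
    spec_tokens={"req_a": 1, "req_b": 2},
    num_logit_rows=5,  # (1 + 1) rows for req_a + (1 + 2) rows for req_b
)
print(out_indices)  # [2, 3, 4]: req_b's rows come after req_a's two rows
```

Collecting `out_indices` and passing it to `xgr.apply_token_bitmask_inplace` (rather than the old `struct_out_req_batch_indices.values()`) is what lets the mask land only on constrained rows, leaving logits for unconstrained requests and their speculative tokens untouched.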