[Core] Async Scheduling X Spec Decoding Compatibility (#24799)
Signed-off-by: Ronald1995 <ronaldautomobile@163.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
|
||||
) -> None:
|
||||
super()._update_after_schedule(scheduler_output)
|
||||
pending_structured_output_tokens = False
|
||||
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
for req_id in scheduler_output.num_scheduled_tokens:
|
||||
request = self.requests[req_id]
|
||||
pending_structured_output_tokens |= (
|
||||
request.use_structured_output and request.num_output_placeholders > 0
|
||||
)
|
||||
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
|
||||
if (
|
||||
request.num_computed_tokens
|
||||
== request.num_tokens + request.num_output_placeholders
|
||||
== request.num_tokens
|
||||
+ request.num_output_placeholders
|
||||
+ cur_num_spec_tokens
|
||||
):
|
||||
# The request will generate a new token in this scheduling step.
|
||||
# TODO(woosuk): Support speculative decoding.
|
||||
request.num_output_placeholders += 1
|
||||
# The request will generate a new token plus num_spec_tokens
|
||||
# in this scheduling step.
|
||||
request.num_output_placeholders += 1 + cur_num_spec_tokens
|
||||
# Add placeholders for the new tokens in spec_token_ids.
|
||||
# Wwe will update the actual spec token ids in the worker process.
|
||||
request.spec_token_ids = [-1] * self.num_spec_tokens
|
||||
|
||||
scheduler_output.pending_structured_output_tokens = (
|
||||
pending_structured_output_tokens
|
||||
|
||||
@@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface):
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (
|
||||
num_new_tokens + request.num_computed_tokens - request.num_tokens
|
||||
num_new_tokens
|
||||
+ request.num_computed_tokens
|
||||
- request.num_tokens
|
||||
- request.num_output_placeholders
|
||||
)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
@@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface):
|
||||
# tokens and rejections. If some tokens are rejected,
|
||||
# num_computed_tokens is decreased by the number of rejected
|
||||
# tokens.
|
||||
request.num_computed_tokens -= num_rejected
|
||||
if request.num_computed_tokens > 0:
|
||||
request.num_computed_tokens -= num_rejected
|
||||
# If async scheduling, num_output_placeholders also includes
|
||||
# the scheduled spec tokens count and so is similarly adjusted.
|
||||
if request.num_output_placeholders > 0:
|
||||
request.num_output_placeholders -= num_rejected
|
||||
spec_decoding_stats = self.make_spec_decoding_stats(
|
||||
spec_decoding_stats,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
|
||||
Reference in New Issue
Block a user