diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 3c66a2320..23c610f3b 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -10,6 +10,11 @@ logger = init_logger(__name__) class AsyncScheduler(Scheduler): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # reusable read-only placeholder list for speculative decoding. + self._spec_token_placeholders: list[int] = [-1] * self.num_spec_tokens + def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) has_structured_output_requests = False @@ -31,9 +36,9 @@ class AsyncScheduler(Scheduler): # The request will generate a new token plus num_spec_tokens # in this scheduling step. request.num_output_placeholders += 1 + cur_num_spec_tokens - # Add placeholders for the new tokens in spec_token_ids. + # Add placeholders for the new draft/spec tokens. # We will update the actual spec token ids in the worker process. - request.spec_token_ids = [-1] * self.num_spec_tokens + request.spec_token_ids = self._spec_token_placeholders scheduler_output.has_structured_output_requests = has_structured_output_requests scheduler_output.pending_structured_output_tokens = ( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1544d847c..869b53601 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -487,9 +487,11 @@ class Scheduler(SchedulerInterface): - request.num_output_placeholders ) if num_scheduled_spec_tokens > 0: - # Trim spec_token_ids list to num_scheduled_spec_tokens. - del request.spec_token_ids[num_scheduled_spec_tokens:] - scheduled_spec_decode_tokens[request_id] = request.spec_token_ids + spec_token_ids = request.spec_token_ids + if len(spec_token_ids) > num_scheduled_spec_tokens: + spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens] + scheduled_spec_decode_tokens[request.request_id] = spec_token_ids + # New spec tokens will be set in `update_draft_token_ids` before the # next step when applicable. request.spec_token_ids = [] @@ -887,7 +889,8 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager.free(request) request.status = RequestStatus.PREEMPTED request.num_computed_tokens = 0 - request.spec_token_ids.clear() + if request.spec_token_ids: + request.spec_token_ids = [] request.num_preemptions += 1 if self.log_stats: request.record_event(EngineCoreEventType.PREEMPTED, timestamp)