[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -126,6 +126,7 @@ class EngineCore:
|
||||
> 1,
|
||||
log_stats=self.log_stats,
|
||||
)
|
||||
self.use_spec_decode = vllm_config.speculative_config is not None
|
||||
|
||||
self.mm_input_cache_server = MultiModalInputCacheServer(
|
||||
vllm_config.model_config, MULTIMODAL_REGISTRY)
|
||||
@@ -294,6 +295,13 @@ class EngineCore:
|
||||
return (engine_core_outputs,
|
||||
scheduler_output.total_num_scheduled_tokens > 0)
|
||||
|
||||
def post_step(self, model_executed: bool) -> None:
    """Forward newly available draft token ids to the scheduler.

    This is a no-op unless speculative decoding is enabled
    (``self.use_spec_decode``) and the model was executed in the step
    that just finished.

    Args:
        model_executed: Whether the preceding step executed the model.
    """
    if not (self.use_spec_decode and model_executed):
        return

    # Take the draft token ids.
    draft_token_ids = self.model_executor.take_draft_token_ids()
    # The executor can return None; only forward an actual result.
    if draft_token_ids is not None:
        self.scheduler.update_draft_token_ids(draft_token_ids)
|
||||
|
||||
def step_with_batch_queue(
|
||||
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
|
||||
"""Schedule and execute batches with the batch queue.
|
||||
@@ -746,6 +754,8 @@ class EngineCoreProc(EngineCore):
|
||||
# Put EngineCoreOutputs into the output queue.
|
||||
for output in (outputs.items() if outputs else ()):
|
||||
self.output_queue.put_nowait(output)
|
||||
# Post-step hook.
|
||||
self.post_step(model_executed)
|
||||
|
||||
return model_executed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user