[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-08-18 17:20:38 -07:00
committed by GitHub
parent 0dd3f4f5ab
commit c9b38be8aa
13 changed files with 100 additions and 64 deletions

View File

@@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@@ -191,6 +191,12 @@ class MultiprocExecutor(Executor):
outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Return the draft token ids reported by the output-rank worker.

    Issues the ``take_draft_token_ids`` RPC through ``collective_rpc``
    with ``unique_reply_rank`` set to ``self.output_rank`` and hands back
    the first element of the reply list.

    Returns:
        The first RPC reply; per the annotation this is a
        ``DraftTokenIds`` instance or ``None``.
    """
    # OPTIMIZATION: only a single worker (output_rank) is asked for its
    # output, rather than gathering a reply from every worker.
    # NOTE(review): presumably collective_rpc returns a one-element list
    # when unique_reply_rank is given -- confirm against its definition.
    replies = self.collective_rpc(
        "take_draft_token_ids",
        unique_reply_rank=self.output_rank,
    )
    return replies[0]
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,