[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -14,7 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
UniProcExecutor as UniProcExecutorV0)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
|
||||
FailureCallback = Callable[[], None]
|
||||
|
||||
@@ -88,6 +88,10 @@ class Executor(ExecutorBase):
|
||||
args=(scheduler_output, ))
|
||||
return output[0]
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Invoke ``take_draft_token_ids`` through ``collective_rpc``.

    Returns:
        The first element of the list returned by ``collective_rpc``.
        NOTE(review): presumably one entry per worker -- confirm against
        the ``collective_rpc`` contract.
    """
    return self.collective_rpc("take_draft_token_ids")[0]
|
||||
|
||||
@property
def max_concurrent_batches(self) -> int:
    """Return the maximum number of concurrent batches; always 1 here."""
    single_batch = 1
    return single_batch
|
||||
|
||||
@@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
|
||||
get_loopback_ip, get_mp_context, get_open_port,
|
||||
set_process_title)
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -191,6 +191,12 @@ class MultiprocExecutor(Executor):
|
||||
outputs, self.output_rank)
|
||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Invoke ``take_draft_token_ids`` via ``collective_rpc`` for one rank.

    Returns:
        The first element of the reply list returned by ``collective_rpc``.
    """
    # OPTIMIZATION: request a reply only from the worker at `output_rank`
    # instead of collecting a result from every worker.
    return self.collective_rpc(
        "take_draft_token_ids",
        unique_reply_rank=self.output_rank,
    )[0]
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
|
||||
Reference in New Issue
Block a user