[V1][Spec decode] Move drafter to model runner (#13363)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-02-17 15:40:12 -08:00
Committed by: GitHub
Parent: 6ac485a953
Commit: cd4a72a28d

9 changed files with 84 additions and 57 deletions

vllm/v1/engine/core.py

@@ -27,7 +27,6 @@ from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -86,15 +85,6 @@ class EngineCore:
                         self.batch_queue_size)
             self.batch_queue = queue.Queue(self.batch_queue_size)
 
-        # Setup speculative decode.
-        # TODO: find a better way to check if we are using ngram.
-        self.use_spec_decode = False
-        if self.scheduler.speculative_config:
-            assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
-                , "Only ngram spec decode is supported in V1."
-            self.proposer = NgramProposer()
-            self.use_spec_decode = True
-
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
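
The setup removed above does not vanish; per the commit title, the drafter moves into the model runner. A minimal sketch of what the runner-side construction plausibly looks like; the class name and attribute names are assumed for illustration, not quoted from this commit:

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

class ModelRunnerSketch:
    # Sketch only: shows where the drafter construction lands after
    # this commit. Not the actual GPUModelRunner code.
    def __init__(self, speculative_config):
        self.use_spec_decode = speculative_config is not None
        if self.use_spec_decode:
            # V1 still supports only ngram-based drafting at this point.
            assert speculative_config.ngram_prompt_lookup_min, (
                "Only ngram spec decode is supported in V1.")
            self.drafter = NgramProposer()
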
@@ -158,9 +148,6 @@ class EngineCore:
             return EngineCoreOutputs(
                 outputs=[], scheduler_stats=self.scheduler.make_stats())
 
-        if self.use_spec_decode:
-            self.propose_tokens()
-
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
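
With this call site removed from step(), drafting now happens inside execute_model, where the freshly sampled token ids are already at hand, instead of being round-tripped through EngineCore before the next schedule(). A hedged sketch of the reshaped per-step flow; run_and_sample and generate_draft_token_ids are illustrative stand-ins, not verbatim vLLM methods:

class RunnerStepSketch:
    # Sketch only: drafting folded into the runner's step.
    def execute_model(self, scheduler_output):
        # Run the target model and sample next tokens as before.
        sampled_token_ids = self.run_and_sample(scheduler_output)
        spec_token_ids = None
        if self.use_spec_decode:
            # Draft immediately after sampling; EngineCore no longer
            # calls propose_tokens() before schedule().
            spec_token_ids = self.generate_draft_token_ids(sampled_token_ids)
        # Sampled and draft ids travel back together in the runner output.
        return sampled_token_ids, spec_token_ids
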
@@ -221,23 +208,6 @@
 
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
 
-    def propose_tokens(self):
-        assert self.scheduler.speculative_config is not None
-        for req in self.scheduler.running:
-            # Ignore requests that are doing chunked prefill.
-            if req.num_computed_tokens < req.num_tokens - 1:
-                continue
-            # Ignore requests that already have spec tokens.
-            if req.spec_token_ids:
-                continue
-            spec_tokens = self.proposer.propose(
-                req.all_token_ids,
-                self.scheduler.speculative_config.ngram_prompt_lookup_min,
-                self.scheduler.speculative_config.num_speculative_tokens,
-            )
-            if spec_tokens:
-                req.append_spec_token_ids(spec_tokens)
-
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
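
For reference, the drafting scheme the removed propose_tokens drove: NgramProposer implements prompt lookup, which checks whether the most recent n tokens already occurred earlier in the sequence and, if so, proposes the k tokens that followed that earlier occurrence as the draft. A self-contained sketch of that scheme; the brute-force backward scan is illustrative, not vLLM's implementation:

from typing import List, Optional

def propose_ngram(token_ids: List[int], n: int, k: int) -> Optional[List[int]]:
    # Prompt-lookup drafting (sketch). n plays the role of
    # ngram_prompt_lookup_min and k of num_speculative_tokens in the
    # removed propose_tokens() call above.
    if len(token_ids) <= n:
        return None
    pattern = token_ids[-n:]
    # Scan backwards so the most recent earlier match wins.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == pattern:
            draft = token_ids[start + n:start + n + k]
            return draft or None
    return None

# The sequence ends in (7, 8); its earlier occurrence was followed by 9,
# so 9 becomes the single draft token.
assert propose_ngram([1, 7, 8, 9, 4, 7, 8], n=2, k=1) == [9]
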