[V1][Spec Decode] Ngram Spec Decode (#12193)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Authored by Lily Liu on 2025-02-15 18:05:11 -08:00, committed by GitHub
parent 367cb8ce8c
commit 80f63a3966
21 changed files with 1023 additions and 82 deletions


@@ -27,6 +27,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -65,6 +66,7 @@ class EngineCore:
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
)
@@ -84,6 +86,15 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self.use_spec_decode = False
if self.scheduler.speculative_config:
assert self.scheduler.speculative_config.ngram_prompt_lookup_min, \
"Only ngram spec decode is supported in V1."
self.proposer = NgramProposer()
self.use_spec_decode = True
def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
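
For context, this is roughly how a user reaches the new code path. The invocation below is only a hedged sketch: the "[ngram]" draft-model convention and the ngram_prompt_lookup_* / num_speculative_tokens arguments follow the pre-V1 speculative-decoding interface and are assumptions, not something this diff shows.

from vllm import LLM

# Hypothetical invocation; argument names are assumed from the V0-era interface.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example target model
    speculative_model="[ngram]",               # drafts come from ngram lookup, not a draft model
    num_speculative_tokens=5,                  # k draft tokens per step
    ngram_prompt_lookup_max=4,                 # longest ngram to match
    ngram_prompt_lookup_min=2,                 # shortest ngram to match (the field asserted above)
)
outputs = llm.generate("The quick brown fox jumps over the lazy dog")
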
@@ -147,6 +158,9 @@ class EngineCore:
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
if self.use_spec_decode:
self.propose_tokens()
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
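
The new propose_tokens() call attaches draft tokens to running requests before scheduling, so the target model can score those drafts in the same execute_model pass. As a minimal sketch (not vLLM's actual acceptance code), greedy verification keeps only the longest prefix of drafts that the target model would itself have emitted:

def count_accepted(draft_tokens: list[int], target_tokens: list[int]) -> int:
    # Longest prefix of the drafts matching the target model's greedy outputs.
    # Illustrative only; the real acceptance logic lives in the scheduler/model runner.
    accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        accepted += 1
    return accepted

# e.g. drafts [11, 12, 13] vs. target argmax [11, 12, 99] -> 2 tokens accepted
assert count_accepted([11, 12, 13], [11, 12, 99]) == 2
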
@@ -207,6 +221,23 @@ class EngineCore:
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
def propose_tokens(self):
assert self.scheduler.speculative_config is not None
for req in self.scheduler.running:
# Ignore requests that are doing chunked prefill.
if req.num_computed_tokens < req.num_tokens - 1:
continue
# Ignore requests that already have spec tokens.
if req.spec_token_ids:
continue
spec_tokens = self.proposer.propose(
req.all_token_ids,
self.scheduler.speculative_config.ngram_prompt_lookup_min,
self.scheduler.speculative_config.num_speculative_tokens,
)
if spec_tokens:
req.append_spec_token_ids(spec_tokens)
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
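
propose_tokens() hands each request's full token history to NgramProposer.propose together with ngram_prompt_lookup_min and num_speculative_tokens. Conceptually this is prompt-lookup decoding: find the most recent earlier occurrence of the sequence's last n tokens and propose the tokens that followed it. A self-contained sketch under those assumptions (the real NgramProposer and its handling of the maximum ngram length will differ in detail):

def ngram_propose(token_ids: list[int], min_n: int, k: int, max_n: int = 4) -> list[int] | None:
    # Try the longest ngram first, falling back to shorter ones down to min_n.
    for n in range(max_n, min_n - 1, -1):
        if len(token_ids) < n + 1:
            continue
        suffix = token_ids[-n:]
        # Scan backwards for an earlier occurrence of the same ngram.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                follow = token_ids[start + n:start + n + k]
                if follow:
                    return follow  # up to k draft tokens
    return None

# e.g. history ... 7 8 9 1 2 3 7 8 9: the trailing "7 8 9" matches the start of
# the sequence, so the drafts are the tokens that followed it there: [1, 2, 3]
assert ngram_propose([7, 8, 9, 1, 2, 3, 7, 8, 9], min_n=2, k=3) == [1, 2, 3]
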