[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)

Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
This commit is contained in:
Cody Yu
2024-05-16 00:53:51 -07:00
committed by GitHub
parent 30e754390c
commit 973617ae02
12 changed files with 295 additions and 180 deletions
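In short: with tensor parallelism only the driver rank receives the scheduler's ExecuteModelRequest, so each step the driver now broadcasts its control-flow decision (number of lookahead slots, whether speculation is disabled for the batch) to the non-driver ranks, and every rank then runs the proposer and scorer the same number of times so collectives stay in lockstep. A minimal sketch of that broadcast-then-act pattern follows; it stands in torch.distributed.broadcast_object_list for vLLM's broadcast_tensor_dict, and the function name is illustrative, not part of this diff.

import torch.distributed as dist  # assumes an initialized process group


def broadcast_control_flow(rank: int,
                           driver_rank: int = 0,
                           num_lookahead_slots: int = 0,
                           disable_all_speculation: bool = False):
    """Driver publishes the per-step decision; non-driver ranks receive it."""
    obj = [None]
    if rank == driver_rank:
        # Only the driver knows the scheduler's decision for this step.
        obj[0] = (num_lookahead_slots, disable_all_speculation)
    # Collective call: every rank must reach this point once per step.
    dist.broadcast_object_list(obj, src=driver_rank)
    return obj[0]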


@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
@@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config")
assert speculative_config is not None
target_worker = Worker(*args, **kwargs)
draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
parallel_config=speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
)
return spec_decode_worker
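For orientation, a rough sketch of how an executor-side caller might reach this entrypoint; the helper below and its exact keyword set are illustrative, not taken from this diff.

from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig


def build_spec_worker(model_config: ModelConfig,
                      parallel_config: ParallelConfig,
                      speculative_config: SpeculativeConfig,
                      **worker_kwargs):
    # Hypothetical executor-side helper: forwards the usual Worker kwargs plus
    # the SpeculativeConfig that create_spec_worker reads for the draft model.
    return create_spec_worker(model_config=model_config,
                              parallel_config=parallel_config,
                              speculative_config=speculative_config,
                              **worker_kwargs)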
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding.
@@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._configure_model_sampler_for_spec_decode()
def load_model(self, *args, **kwargs):
pass
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
@@ -195,39 +231,97 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def _broadcast_control_flow_decision(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
disable_all_speculation: bool = False) -> Tuple[int, bool]:
"""Broadcast how many lookahead slots are scheduled for this step, and
whether all speculation is disabled, to all non-driver workers.
This is required because the number of draft model runs can change
dynamically, and the non-driver workers won't know unless we perform a
communication to inform them.
Returns the broadcasted num_lookahead_slots and disable_all_speculation.
"""
if self.rank == self._driver_rank:
assert execute_model_req is not None
broadcast_dict = dict(
num_lookahead_slots=execute_model_req.num_lookahead_slots,
disable_all_speculation=disable_all_speculation,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
else:
assert execute_model_req is None
broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
return (broadcast_dict["num_lookahead_slots"],
broadcast_dict["disable_all_speculation"])
@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
disable_all_speculation = False
if self.rank == self._driver_rank:
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
(num_lookahead_slots,
disable_all_speculation) = self._broadcast_control_flow_decision(
execute_model_req, disable_all_speculation)
if self.rank == self._driver_rank:
assert execute_model_req is not None
assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding requires non-None seq_group_metadata_list"
)
self._maybe_disable_speculative_tokens(
disable_all_speculation,
execute_model_req.seq_group_metadata_list)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
else:
self._run_non_driver_rank(num_lookahead_slots)
return []
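To make the rank split above concrete, here is an illustrative per-rank dispatch; spec_worker and get_scheduler_output are hypothetical stand-ins for the executor plumbing around this class.

def decode_step(spec_worker, rank: int, driver_rank: int = 0,
                get_scheduler_output=None):
    # Illustrative only: the driver passes the real request, every other
    # tensor-parallel rank calls execute_model() with no request and follows
    # the broadcast decision (num_lookahead_slots, disable_all_speculation).
    if rank == driver_rank:
        execute_model_req = get_scheduler_output()
        return spec_worker.execute_model(execute_model_req)
    return spec_worker.execute_model()  # returns [] on non-driver ranks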
def _should_disable_all_speculation(
self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all_speculation = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
return disable_all_speculation
def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
if not disable_all_speculation:
return
for seq_group_metadata in seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
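A worked example of the two pieces above, with an assumed threshold of 32 (in practice the value comes from speculative_disable_by_batch_size).

# Illustrative numbers only: a step whose running queue holds 40 sequence
# groups crosses the threshold, so speculation is disabled for the step and
# _maybe_disable_speculative_tokens pins num_speculative_tokens to 0, which
# keeps those requests non-speculative from then on.
running_queue_size = 40       # hypothetical scheduler state
disable_by_batch_size = 32    # hypothetical speculative_disable_by_batch_size
disable_all_speculation = running_queue_size >= disable_by_batch_size
print(disable_all_speculation)  # True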
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
@@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output.logprobs = None
return [sampler_output]
def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
"""
# In non-driver workers the input is None
execute_model_req = None
# Even if num_lookahead_slots is zero, we want to run the proposer model
# as it may have KV cache work to do (e.g. during prefill).
#
# We run the proposer once per lookahead slot. In the future we should
# delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model(execute_model_req)
self.scorer_worker.execute_model(execute_model_req)
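A small worked example of the loop count used above (illustrative only):

# Prefill arrives with num_lookahead_slots == 0, but the proposer still runs
# once so its KV cache sees the new prompt tokens; decode steps run it once
# per lookahead slot.
for num_lookahead_slots in (0, 1, 4):
    proposer_runs = max(num_lookahead_slots, 1)
    print(num_lookahead_slots, "->", proposer_runs)  # 0 -> 1, 1 -> 1, 4 -> 4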
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
self, execute_model_req: ExecuteModelRequest,
num_lookahead_slots: int) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
@@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
@@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def device(self):
return self.scorer_worker.device
@property
def _driver_rank(self) -> int:
return 0
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes.