[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
Co-authored-by: Cade Daniel <edacih@gmail.com> Co-authored-by: Cade Daniel <cade@anyscale.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
@@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output,
|
||||
get_all_num_logprobs, get_all_seq_ids,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
"""Helper method that is the entrypoint for Executors which use
|
||||
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
|
||||
"""
|
||||
assert "speculative_config" in kwargs
|
||||
speculative_config = kwargs.get("speculative_config")
|
||||
assert speculative_config is not None
|
||||
|
||||
target_worker = Worker(*args, **kwargs)
|
||||
|
||||
draft_worker_kwargs = kwargs.copy()
|
||||
# Override draft-model specific worker args.
|
||||
draft_worker_kwargs.update(
|
||||
model_config=speculative_config.draft_model_config,
|
||||
parallel_config=speculative_config.draft_parallel_config,
|
||||
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
|
||||
# TODO allow draft-model specific load config.
|
||||
#load_config=load_config,
|
||||
)
|
||||
|
||||
spec_decode_worker = SpecDecodeWorker.create_worker(
|
||||
scorer_worker=target_worker,
|
||||
draft_worker_kwargs=draft_worker_kwargs,
|
||||
disable_by_batch_size=speculative_config.
|
||||
speculative_disable_by_batch_size,
|
||||
)
|
||||
|
||||
return spec_decode_worker
|
||||
|
||||
|
||||
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
"""Worker which implements speculative decoding.
|
||||
|
||||
@@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
self._configure_model_sampler_for_spec_decode()
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _configure_model_sampler_for_spec_decode(self):
|
||||
"""Configure model sampler to emit GPU tensors. This allows spec decode
|
||||
to keep data on device without transferring to CPU and serializing,
|
||||
@@ -195,39 +231,97 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks)
|
||||
|
||||
def _broadcast_control_flow_decision(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
disable_all_speculation: bool = False) -> Tuple[int, bool]:
|
||||
"""Broadcast how many lookahead slots are scheduled for this step, and
|
||||
whether all speculation is disabled, to all non-driver workers.
|
||||
|
||||
This is required as if the number of draft model runs changes
|
||||
dynamically, the non-driver workers won't know unless we perform a
|
||||
communication to inform then.
|
||||
|
||||
Returns the broadcasted num_lookahead_slots and disable_all_speculation.
|
||||
"""
|
||||
|
||||
if self.rank == self._driver_rank:
|
||||
assert execute_model_req is not None
|
||||
|
||||
broadcast_dict = dict(
|
||||
num_lookahead_slots=execute_model_req.num_lookahead_slots,
|
||||
disable_all_speculation=disable_all_speculation,
|
||||
)
|
||||
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
|
||||
else:
|
||||
assert execute_model_req is None
|
||||
broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
|
||||
|
||||
return (broadcast_dict["num_lookahead_slots"],
|
||||
broadcast_dict["disable_all_speculation"])
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
"""Perform speculative decoding on the input batch.
|
||||
"""
|
||||
|
||||
assert execute_model_req.seq_group_metadata_list is not None, (
|
||||
"speculative decoding "
|
||||
"requires non-None seq_group_metadata_list")
|
||||
disable_all_speculation = False
|
||||
if self.rank == self._driver_rank:
|
||||
disable_all_speculation = self._should_disable_all_speculation(
|
||||
execute_model_req)
|
||||
|
||||
(num_lookahead_slots,
|
||||
disable_all_speculation) = self._broadcast_control_flow_decision(
|
||||
execute_model_req, disable_all_speculation)
|
||||
|
||||
if self.rank == self._driver_rank:
|
||||
assert execute_model_req is not None
|
||||
assert execute_model_req.seq_group_metadata_list is not None, (
|
||||
"speculative decoding requires non-None seq_group_metadata_list"
|
||||
)
|
||||
|
||||
self._maybe_disable_speculative_tokens(
|
||||
disable_all_speculation,
|
||||
execute_model_req.seq_group_metadata_list)
|
||||
|
||||
# If no spec tokens, call the proposer and scorer workers normally.
|
||||
# Used for prefill.
|
||||
if num_lookahead_slots == 0 or len(
|
||||
execute_model_req.seq_group_metadata_list) == 0:
|
||||
return self._run_no_spec(execute_model_req,
|
||||
skip_proposer=disable_all_speculation)
|
||||
|
||||
return self._run_speculative_decoding_step(execute_model_req,
|
||||
num_lookahead_slots)
|
||||
else:
|
||||
self._run_non_driver_rank(num_lookahead_slots)
|
||||
return []
|
||||
|
||||
def _should_disable_all_speculation(
|
||||
self, execute_model_req: ExecuteModelRequest) -> bool:
|
||||
# When the batch size is too large, disable speculative decoding
|
||||
# to stop trading off throughput for latency.
|
||||
disable_all = (execute_model_req.running_queue_size >=
|
||||
self.disable_by_batch_size)
|
||||
if disable_all:
|
||||
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
|
||||
# Once num_speculative_tokens is set to 0, the spec decode
|
||||
# of this request will be disabled forever.
|
||||
# TODO(comaniac): We currently store spec decoding specific
|
||||
# state in the global data structure, but we should maintain
|
||||
# this state within spec decode worker.
|
||||
seq_group_metadata.num_speculative_tokens = 0
|
||||
disable_all_speculation = (execute_model_req.running_queue_size >=
|
||||
self.disable_by_batch_size)
|
||||
|
||||
# If no spec tokens, call the proposer and scorer workers normally.
|
||||
# This happens for prefill, or when the spec decode is disabled
|
||||
# for this batch.
|
||||
if execute_model_req.num_lookahead_slots == 0 or len(
|
||||
execute_model_req.seq_group_metadata_list) == 0:
|
||||
return self._run_no_spec(execute_model_req,
|
||||
skip_proposer=disable_all)
|
||||
return disable_all_speculation
|
||||
|
||||
return self._run_speculative_decoding_step(execute_model_req)
|
||||
def _maybe_disable_speculative_tokens(
|
||||
self, disable_all_speculation: bool,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
|
||||
if not disable_all_speculation:
|
||||
return
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
# Once num_speculative_tokens is set to 0, the spec decode
|
||||
# of this request will be disabled forever.
|
||||
# TODO(comaniac): We currently store spec decoding specific
|
||||
# state in the global data structure, but we should maintain
|
||||
# this state within spec decode worker.
|
||||
seq_group_metadata.num_speculative_tokens = 0
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
|
||||
@@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
sampler_output.logprobs = None
|
||||
return [sampler_output]
|
||||
|
||||
def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
|
||||
"""Run proposer and verifier model in non-driver workers. This is used
|
||||
for both speculation cases (num_lookahead_slots>0) and non-speculation
|
||||
cases (e.g. prefill).
|
||||
"""
|
||||
# In non-driver workers the input is None
|
||||
execute_model_req = None
|
||||
|
||||
# Even if num_lookahead_slots is zero, we want to run the proposer model
|
||||
# as it may have KV.
|
||||
#
|
||||
# We run the proposer once per lookahead slot. In the future we should
|
||||
# delegate how many times it runs to the proposer.
|
||||
for _ in range(max(num_lookahead_slots, 1)):
|
||||
self.proposer_worker.execute_model(execute_model_req)
|
||||
|
||||
self.scorer_worker.execute_model(execute_model_req)
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
||||
def _run_speculative_decoding_step(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
self, execute_model_req: ExecuteModelRequest,
|
||||
num_lookahead_slots: int) -> List[SamplerOutput]:
|
||||
"""Execute a single step of speculative decoding.
|
||||
|
||||
This invokes the proposer worker to get k speculative tokens for each
|
||||
@@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
Returns a list of SamplerOutput, each containing a single token per
|
||||
sequence.
|
||||
"""
|
||||
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
|
||||
|
||||
# Generate proposals using draft worker.
|
||||
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
|
||||
@@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
def device(self):
|
||||
return self.scorer_worker.device
|
||||
|
||||
@property
|
||||
def _driver_rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def get_cache_block_size_bytes(self):
|
||||
"""Return the size of a cache block in bytes.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user