[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
This commit is contained in:
@@ -4,7 +4,8 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@@ -46,15 +47,15 @@ Timer = Callable[[], float]
|
||||
|
||||
|
||||
class AsyncMetricsCollector:
|
||||
"""Class which copies rejection sampler metrics from the device to CPU on a
|
||||
non-default Torch stream.
|
||||
"""Class which copies rejection/typical-acceptance sampler metrics
|
||||
from the device to CPU on a non-default Torch stream.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rejection_sampler: RejectionSampler,
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
timer: Optional[Timer] = None,
|
||||
collect_interval_s: float = 5.0):
|
||||
self._rejection_sampler = rejection_sampler
|
||||
self.spec_decode_sampler = spec_decode_sampler
|
||||
self._timer = time.time if timer is None else timer
|
||||
|
||||
self._rank: Optional[int] = None
|
||||
@@ -95,7 +96,7 @@ class AsyncMetricsCollector:
|
||||
return None
|
||||
|
||||
def _should_collect_rejsample_metrics(self, now: float) -> bool:
|
||||
"""Return whether or not this iteration should print rejection sampling
|
||||
"""Return whether or not this iteration should print sampling
|
||||
metrics.
|
||||
"""
|
||||
if self._rank != 0:
|
||||
@@ -107,8 +108,8 @@ class AsyncMetricsCollector:
|
||||
return True
|
||||
|
||||
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
|
||||
"""Copy rejection sampling metrics (number of accepted tokens, etc) to
|
||||
CPU asynchronously.
|
||||
"""Copy rejection/typical-acceptance sampling metrics
|
||||
(number of accepted tokens, etc) to CPU asynchronously.
|
||||
|
||||
Returns a CUDA event recording when the copy is complete.
|
||||
"""
|
||||
@@ -117,13 +118,14 @@ class AsyncMetricsCollector:
|
||||
|
||||
with torch.cuda.stream(self._copy_stream):
|
||||
self._aggregate_num_accepted_tokens.copy_(
|
||||
self._rejection_sampler.num_accepted_tokens, non_blocking=True)
|
||||
self.spec_decode_sampler.num_accepted_tokens,
|
||||
non_blocking=True)
|
||||
self._aggregate_num_emitted_tokens.copy_(
|
||||
self._rejection_sampler.num_emitted_tokens, non_blocking=True)
|
||||
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
|
||||
# Number of draft tokens is calculated on CPU, so no copy is
|
||||
# required.
|
||||
self._aggregate_num_draft_tokens = (
|
||||
self._rejection_sampler.num_draft_tokens)
|
||||
self.spec_decode_sampler.num_draft_tokens)
|
||||
|
||||
aggregate_metrics_ready = torch.cuda.Event()
|
||||
aggregate_metrics_ready.record(self._copy_stream)
|
||||
|
||||
@@ -7,6 +7,10 @@ from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
HiddenStates, SamplerOutput, SequenceGroupMetadata,
|
||||
get_all_seq_ids)
|
||||
@@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
draft_worker_kwargs=draft_worker_kwargs,
|
||||
disable_by_batch_size=speculative_config.
|
||||
speculative_disable_by_batch_size,
|
||||
)
|
||||
draft_token_acceptance_method=speculative_config.
|
||||
draft_token_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=speculative_config.
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=speculative_config.
|
||||
typical_acceptance_sampler_posterior_alpha)
|
||||
|
||||
return spec_decode_worker
|
||||
|
||||
@@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
welcome!).
|
||||
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
|
||||
future work.
|
||||
* Only lossless rejection sampling is supported. Contributions adding lossy
|
||||
verification routines are welcome (e.g. Medusa's typical acceptance).
|
||||
* All sequences in a batch must have the same proposal length, or zero. This
|
||||
can be improved by having per-sequence speculation in the future.
|
||||
* The scoring forward pass is done without an MQA kernel, which is
|
||||
@@ -95,6 +102,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
scorer_worker: Worker,
|
||||
draft_worker_kwargs: Dict[str, Any],
|
||||
disable_by_batch_size: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: float,
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
ngram_prompt_lookup_max = (
|
||||
@@ -127,17 +137,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
logger.info("Configuring SpecDecodeWorker with proposer=%s",
|
||||
type(proposer_worker))
|
||||
|
||||
spec_decode_sampler: SpecDecodeBaseSampler = None
|
||||
if draft_token_acceptance_method == "rejection_sampler":
|
||||
spec_decode_sampler = RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens, )
|
||||
elif draft_token_acceptance_method == "typical_acceptance_sampler":
|
||||
spec_decode_sampler = TypicalAcceptanceSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
posterior_threshold=\
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
logger.info("Configuring SpecDecodeWorker with sampler=%s",
|
||||
type(spec_decode_sampler))
|
||||
|
||||
return SpecDecodeWorker(proposer_worker,
|
||||
scorer_worker,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
rejection_sampler=RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens))
|
||||
spec_decode_sampler=spec_decode_sampler)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proposer_worker: ProposerWorkerBase,
|
||||
scorer_worker: WorkerBase,
|
||||
rejection_sampler: RejectionSampler,
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||
disable_by_batch_size: Optional[int] = None,
|
||||
):
|
||||
@@ -150,8 +173,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
scorer_worker: A worker that produces probabilities of speculative
|
||||
tokens according to some base model. Typically a vanilla vLLM
|
||||
Worker.
|
||||
rejection_sampler: A Torch module used to perform modified rejection
|
||||
sampling for speculative decoding.
|
||||
spec_decode_sampler: A Torch module used to perform acceptance
|
||||
sampling of the draft tokens in the verification step of
|
||||
speculative decoding. Currently we support two different
|
||||
types of samplers, namely RejectionSampler and
|
||||
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
|
||||
instance of RejectionSampler or TypicalAcceptanceSampler.
|
||||
disable_by_batch_size: If the batch size is larger than this,
|
||||
disable speculative decoding for new incoming requests.
|
||||
metrics_collector: Helper class for collecting metrics; can be set
|
||||
@@ -160,15 +187,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.proposer_worker = proposer_worker
|
||||
self.scorer_worker = scorer_worker
|
||||
self.disable_by_batch_size = disable_by_batch_size or float("inf")
|
||||
self.rejection_sampler = rejection_sampler
|
||||
|
||||
self.spec_decode_sampler = spec_decode_sampler
|
||||
self._metrics = AsyncMetricsCollector(
|
||||
rejection_sampler
|
||||
self.spec_decode_sampler
|
||||
) if metrics_collector is None else metrics_collector
|
||||
|
||||
self.probs_dtype = self.rejection_sampler.probs_dtype
|
||||
self.token_id_dtype = self.rejection_sampler.token_id_dtype
|
||||
|
||||
self.probs_dtype = self.spec_decode_sampler.probs_dtype
|
||||
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
|
||||
# Lazy initialization.
|
||||
self.scorer: SpeculativeScorer
|
||||
|
||||
@@ -189,7 +213,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.proposer_worker.load_model()
|
||||
|
||||
self._metrics.init_gpu_tensors(self.rank)
|
||||
self.rejection_sampler.init_gpu_tensors(self.rank)
|
||||
self.spec_decode_sampler.init_gpu_tensors(self.rank)
|
||||
|
||||
self.scorer = BatchExpansionTop1Scorer(
|
||||
scorer_worker=self.scorer_worker,
|
||||
device=self.device,
|
||||
@@ -203,7 +228,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
def _configure_model_sampler_for_spec_decode(self):
|
||||
"""Configure model sampler to emit GPU tensors. This allows spec decode
|
||||
to keep data on device without transferring to CPU and serializing,
|
||||
which significantly reduces overhead of rejection sampling.
|
||||
which significantly reduces overhead of sampling during verification.
|
||||
|
||||
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
|
||||
design is to have the "move to CPU and serialize" sampling decision be
|
||||
@@ -481,7 +506,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# Get proposed tokens.
|
||||
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
|
||||
|
||||
accepted_token_ids = self.rejection_sampler(
|
||||
accepted_token_ids = self.spec_decode_sampler(
|
||||
target_probs=proposal_verifier_probs,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
draft_probs=proposal_probs,
|
||||
@@ -496,7 +521,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
accepted_token_ids = torch.cat(
|
||||
[accepted_token_ids, non_spec_token_ids])
|
||||
logprobs = proposal_scores.logprobs
|
||||
|
||||
# Rearrange so that results are in the order of the original seq group
|
||||
# metadata.
|
||||
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||
|
||||
Reference in New Issue
Block a user