[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)

This commit is contained in:
sroy745
2024-07-01 00:33:05 -07:00
committed by GitHub
parent 614aa51203
commit 80ca1e6a3a
14 changed files with 480 additions and 208 deletions

View File

@@ -4,7 +4,8 @@ from typing import Callable, Optional
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.utils import is_pin_memory_available
@@ -46,15 +47,15 @@ Timer = Callable[[], float]
class AsyncMetricsCollector:
"""Class which copies rejection sampler metrics from the device to CPU on a
non-default Torch stream.
"""Class which copies rejection/typical-acceptance sampler metrics
from the device to CPU on a non-default Torch stream.
"""
def __init__(self,
rejection_sampler: RejectionSampler,
spec_decode_sampler: SpecDecodeBaseSampler,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self._rejection_sampler = rejection_sampler
self.spec_decode_sampler = spec_decode_sampler
self._timer = time.time if timer is None else timer
self._rank: Optional[int] = None
@@ -95,7 +96,7 @@ class AsyncMetricsCollector:
return None
def _should_collect_rejsample_metrics(self, now: float) -> bool:
"""Return whether or not this iteration should print rejection sampling
"""Return whether or not this iteration should print sampling
metrics.
"""
if self._rank != 0:
@@ -107,8 +108,8 @@ class AsyncMetricsCollector:
return True
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection sampling metrics (number of accepted tokens, etc) to
CPU asynchronously.
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete.
"""
@@ -117,13 +118,14 @@ class AsyncMetricsCollector:
with torch.cuda.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self._rejection_sampler.num_accepted_tokens, non_blocking=True)
self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self._rejection_sampler.num_emitted_tokens, non_blocking=True)
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self._rejection_sampler.num_draft_tokens)
self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.cuda.Event()
aggregate_metrics_ready.record(self._copy_stream)

View File

@@ -7,6 +7,10 @@ from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids)
@@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
)
draft_token_acceptance_method=speculative_config.
draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha)
return spec_decode_worker
@@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* Only lossless rejection sampling is supported. Contributions adding lossy
verification routines are welcome (e.g. Medusa's typical acceptance).
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
@@ -95,6 +102,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: Worker,
draft_worker_kwargs: Dict[str, Any],
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
) -> "SpecDecodeWorker":
ngram_prompt_lookup_max = (
@@ -127,17 +137,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens, )
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))
return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
rejection_sampler=RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens))
spec_decode_sampler=spec_decode_sampler)
def __init__(
self,
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler,
spec_decode_sampler: SpecDecodeBaseSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
):
@@ -150,8 +173,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding.
spec_decode_sampler: A Torch module used to perform acceptance
sampling of the draft tokens in the verification step of
speculative decoding. Currently we support two different
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
@@ -160,15 +187,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.rejection_sampler = rejection_sampler
self.spec_decode_sampler = spec_decode_sampler
self._metrics = AsyncMetricsCollector(
rejection_sampler
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
self.probs_dtype = self.rejection_sampler.probs_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype
self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initialization.
self.scorer: SpeculativeScorer
@@ -189,7 +213,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.load_model()
self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker,
device=self.device,
@@ -203,7 +228,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling.
which significantly reduces overhead of sampling during verification.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
@@ -481,7 +506,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
accepted_token_ids = self.rejection_sampler(
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
@@ -496,7 +521,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()