[Speculative Decoding] Medusa Implementation with Top-1 proposer (#4978)

This commit is contained in:
Abhinav Goyal
2024-07-10 07:04:02 +05:30
committed by GitHub
parent d3a245138a
commit 2416b26e11
9 changed files with 587 additions and 4 deletions

View File

@@ -0,0 +1,127 @@
import weakref
from typing import List, Optional, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Lazy initialization list.
self._proposer: Top1Proposer
def init_device(self):
super().init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
pass
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states,
sampling_metadata=sampling_metadata)
return model_outputs, False
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[List[int], List[int]]:
if not seq_group_metadata_list:
return [], []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
seq_lens.append(seq_len)
query_lens.append(seq_len - context_len)
else:
seq_lens.append(seq_data_len)
query_lens.append(1)
return seq_lens, query_lens
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(execute_model_req)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MedusaWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MedusaWorker does not support beam search.")

View File

@@ -18,6 +18,7 @@ from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -129,6 +130,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"model_config"].hf_config.model_type == "mlp_speculator":
disable_bonus_tokens = False
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "medusa":
disable_bonus_tokens = False
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[