[Model] MLPSpeculator speculative decoding support (#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com> Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
This commit is contained in:
committed by
GitHub
parent
6c5b7af152
commit
b12518d3cf
@@ -4,11 +4,10 @@ from typing import Iterator, List, Tuple
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||
sampler_output_to_torch,
|
||||
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
@@ -98,6 +97,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
probs=all_probs,
|
||||
token_ids=all_tokens,
|
||||
logprobs=spec_logprobs,
|
||||
hidden_states=target_sampler_output.hidden_states,
|
||||
)
|
||||
|
||||
def _expand_batch(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -46,6 +47,9 @@ class SpeculativeScores:
|
||||
# tokens and also non-speculative normal decoding.
|
||||
token_ids: torch.Tensor
|
||||
|
||||
# Optional last hidden states from the scoring model.
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
def __repr__(self):
|
||||
return (f"SpeculativeScores("
|
||||
f"probs={self.probs.shape}, "
|
||||
|
||||
87
vllm/spec_decode/mlp_speculator_worker.py
Normal file
87
vllm/spec_decode/mlp_speculator_worker.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.worker.model_runner import ModelInput
|
||||
|
||||
|
||||
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
    """Proposer worker that drafts tokens with an MLPSpeculator model.

    The speculator is driven by the last sampled token of each sequence
    together with the hidden states handed over on the request
    (``execute_model_req.previous_hidden_states``).

    Not currently compatible with LoRA or chunked prefill.
    """

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the speculator to propose ``sample_len`` future tokens.

        Args:
            execute_model_req: Request carrying the sequence group metadata
                and the previous hidden states consumed by the speculator.
            sample_len: Number of tokens to propose per sequence.

        Returns:
            A tuple ``(outputs, transposed)``. ``outputs`` is the list
            returned by the model's ``generate_proposals`` (asserted to have
            exactly ``sample_len`` entries). ``transposed`` tells the later
            ``sampler_output_to_torch`` step whether the tensors need to be
            transposed; for this worker it is always ``True``.

        Raises:
            ValueError: If the request carries no previous hidden states.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Fail with an explicit message instead of an opaque AttributeError
        # on ``None.hidden_states`` further down.
        previous_hidden_states = execute_model_req.previous_hidden_states
        if previous_hidden_states is None:
            raise ValueError(
                "MLPSpeculatorWorker requires previous_hidden_states to be "
                "set on the ExecuteModelRequest.")

        (input_tokens, seq_lens,
         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory)

        model_outputs = self.model_runner.model.generate_proposals(
            input_ids=input_tokens,
            previous_hidden_states=previous_hidden_states.hidden_states,
            num_predict_tokens=sample_len,
            sampling_metadata=sampling_metadata)

        # Internal invariant: one output per requested proposal step.
        assert len(model_outputs) == sample_len

        return model_outputs, True

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """Flatten the batch into the inputs needed by the speculator.

        Args:
            seq_group_metadata_list: Sequence groups of the current batch;
                may be ``None`` or empty.

        Returns:
            A tuple ``(input_tokens, seq_lens, query_lens)`` where
            ``input_tokens`` is a 1-D ``torch.long`` tensor on
            ``self.device`` and the two lists hold one entry per sequence.
            For an empty batch the tensor has zero elements and both lists
            are empty.
        """
        if not seq_group_metadata_list:
            # Keep the declared 3-tuple shape so the caller's unpacking
            # works for an empty batch as well.
            empty_tokens = torch.empty(0,
                                       dtype=torch.long,
                                       device=self.device)
            return empty_tokens, [], []

        input_tokens: List[int] = []
        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    # Prompt step: feed the not-yet-computed tokens, capped
                    # by the group's token_chunk_size.
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                    seq_lens.append(seq_len)
                    input_tokens.extend(tokens)
                    query_lens.append(seq_len - context_len)
                else:
                    # Decode step: only the last token is fed.
                    seq_lens.append(seq_data_len)
                    input_tokens.append(seq_data.get_last_token_id())
                    query_lens.append(1)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        return input_tokens_tensor, seq_lens, query_lens
|
||||
@@ -8,16 +8,18 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
SamplerOutput, SequenceGroupMetadata)
|
||||
HiddenStates, SamplerOutput, SequenceGroupMetadata,
|
||||
get_all_seq_ids)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.util import (create_sequence_group_output,
|
||||
get_all_num_logprobs, get_all_seq_ids,
|
||||
get_all_num_logprobs,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker import Worker
|
||||
@@ -104,6 +106,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
||||
ngram_prompt_lookup_max)
|
||||
elif draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "mlp_speculator":
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
disable_bonus_tokens = False
|
||||
else:
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
|
||||
@@ -155,6 +161,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# Lazy initialization.
|
||||
self.scorer: SpeculativeScorer
|
||||
|
||||
# Hidden states from target model to pass to proposer
|
||||
# in the subsequent step.
|
||||
self.previous_hidden_states: Optional[HiddenStates] = None
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
"""
|
||||
@@ -337,6 +347,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
assert len(sampler_output) == 1
|
||||
sampler_output = sampler_output[0]
|
||||
|
||||
# Store hidden states from target model execution.
|
||||
hidden_states = sampler_output.hidden_states
|
||||
if hidden_states is not None:
|
||||
if self.previous_hidden_states is None:
|
||||
self.previous_hidden_states = HiddenStates(
|
||||
execute_model_req.seq_group_metadata_list, hidden_states)
|
||||
else:
|
||||
self.previous_hidden_states.update(
|
||||
execute_model_req.seq_group_metadata_list, hidden_states)
|
||||
|
||||
# Clear device tensors from sampler output. This reduces communication
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
sampler_output.probs = None
|
||||
@@ -383,6 +403,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
"""
|
||||
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
|
||||
|
||||
# Pass last hidden states from target model to proposer
|
||||
execute_model_req.previous_hidden_states = self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
# Generate proposals using draft worker.
|
||||
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
|
||||
|
||||
@@ -466,6 +490,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# metadata.
|
||||
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||
|
||||
hidden_states = proposal_scores.hidden_states
|
||||
if hidden_states is not None:
|
||||
# Contract hidden states based on accepted tokens
|
||||
hs_size = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
|
||||
hs_size)
|
||||
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
|
||||
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
|
||||
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
|
||||
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
|
||||
# Store hidden states from target model for subsequent decode step
|
||||
self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
|
||||
hidden_states)
|
||||
|
||||
return accepted_token_ids, logprobs
|
||||
|
||||
def _create_output_sampler_list(
|
||||
|
||||
@@ -65,9 +65,13 @@ class Top1Proposer(SpeculativeProposer):
|
||||
# token_ids is like [batch] format in proposal_len size list,
|
||||
# while if it is false, the format would be [proposal_len]
|
||||
# in batch size list
|
||||
hidden_states = execute_model_req.previous_hidden_states
|
||||
if hidden_states is not None:
|
||||
hidden_states.prune(nonzero_proposal_len_seqs)
|
||||
nonzero_execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=nonzero_proposal_len_seqs,
|
||||
num_lookahead_slots=proposal_len,
|
||||
previous_hidden_states=hidden_states,
|
||||
)
|
||||
maybe_sampler_output, transposed = self._worker.sampler_output(
|
||||
execute_model_req=nonzero_execute_model_req,
|
||||
|
||||
@@ -10,14 +10,6 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SeqId = int
|
||||
|
||||
|
||||
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
    """Collect every sequence id found in the given sequence groups.

    The ids are the keys of each group's ``seq_data`` mapping, returned in
    group order and, within a group, in the mapping's iteration order.
    """
    all_seq_ids = []
    for seq_group in seq_group_metadata_list:
        # Extending with a mapping appends its keys, i.e. the sequence ids.
        all_seq_ids.extend(seq_group.seq_data)
    return all_seq_ids
|
||||
|
||||
|
||||
def get_all_num_logprobs(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
|
||||
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
|
||||
|
||||
Reference in New Issue
Block a user