[Misc][Refactor] Introduce ExecuteModelData (#4540)

Author: Cody Yu
Date: 2024-05-03 17:47:07 -07:00
Committed by: GitHub
Parent: 344bf7cd2d
Commit: bc8ad68455
23 changed files with 355 additions and 511 deletions
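
The diffs below collapse the loose (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k / max_proposal_len / num_lookahead_slots) argument bundle into a single request object. Despite the ExecuteModelData name in the title, the class the code actually imports is ExecuteModelRequest from vllm.sequence. The following is a minimal sketch of that container, reconstructed only from how the diffs use it (field names, keyword arguments, and the clone() call); it is not the real definition in vllm/sequence.py, which may carry additional fields and defaults.

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ExecuteModelRequest:
    """Sketch of the request object, inferred from its usage in the diffs below."""
    # Metadata for every sequence group in the batch.
    seq_group_metadata_list: List["SequenceGroupMetadata"]
    # Cache-management maps that previously travelled as separate arguments.
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    # Replaces the old standalone `k` / `max_proposal_len` parameter.
    num_lookahead_slots: int = 0

    def clone(self, seq_group_metadata_list) -> "ExecuteModelRequest":
        # Same request, new batch: used when a worker rewrites the batch
        # (e.g. batch expansion) but must keep the cache ops and lookahead count.
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=dict(self.blocks_to_swap_in),
            blocks_to_swap_out=dict(self.blocks_to_swap_out),
            blocks_to_copy=dict(self.blocks_to_copy),
            num_lookahead_slots=self.num_lookahead_slots,
        )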

View File

@@ -1,9 +1,10 @@
from itertools import chain, count
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Iterator, List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence.
Args:
seq_group_metadata_list: The input sequence group metadata.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
target_sampler_output = self._scorer_worker.execute_model(
seq_group_metadata_list=target_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list, ))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(seq_group_metadata_list),
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=k,
k=execute_model_req.num_lookahead_slots,
)
return SpeculativeScores(
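
In the BatchExpansionTop1Scorer hunks above, the scorer now takes only the request plus the proposals: the batch and the proposal length k are read off the request, and the expanded target batch is forwarded to the scorer worker through execute_model_req.clone(...). A hedged sketch of what a caller of the new signature looks like (the wrapper and its variable names are illustrative, not an exact vLLM call site):

from vllm.sequence import ExecuteModelRequest


def score_after_refactor(scorer, seq_group_metadata_list, k, proposals):
    """Hypothetical caller: the old (batch, blocks_to_*, k) arguments become one request."""
    return scorer.score_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k,  # previously the separate `k` argument
        ),
        proposals=proposals,
    )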

View File

@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
from vllm.sequence import SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest
@dataclass
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
@abstractmethod
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
raise NotImplementedError
@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
@abstractmethod
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
raise NotImplementedError
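
With the abstract interfaces narrowed to one execute_model_req parameter, every proposer and scorer shares the same entry point. A minimal, non-functional sketch of a proposer conforming to the new contract (the class name and body are placeholders, not vLLM code):

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)


class PlaceholderProposer(SpeculativeProposer):
    """Shows the post-refactor contract only; raises instead of proposing."""

    def get_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        # Everything a proposer needs now hangs off the request.
        batch = execute_model_req.seq_group_metadata_list
        proposal_len = execute_model_req.num_lookahead_slots
        raise NotImplementedError(
            f"would propose {proposal_len} tokens for {len(batch)} groups")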

View File

@@ -1,9 +1,10 @@
import copy
from typing import Dict, List, Tuple
from typing import List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
@torch.inference_mode()
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
self._raise_if_unsupported(execute_model_req)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)
# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, sample_len)
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
# Run model sample_len times.
model_outputs = []
for _ in range(sample_len):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
return self._proposer.get_proposals(execute_model_req)
def _append_new_tokens(
self, model_output: SamplerOutput,
@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")

View File

@@ -1,8 +1,8 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# NGram don't need gpu sampler
pass
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> None:
def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
"""NGram doesn't depend on model execution, just pass this function"""
pass
@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self._raise_if_unsupported(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
self._raise_if_unsupported(execute_model_req)
arr = []
has_spec_out = False
for seq_group_metadata in seq_group_metadata_list:
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))
input_ids = torch.as_tensor(seq_data.get_token_ids(),
@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
indices = token_ids.unsqueeze(2)
token_probs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
for i in range(len(seq_group_metadata_list)):
for i in range(len(execute_model_req.seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,
@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
return self._proposer.get_proposals(execute_model_req)
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
) -> None:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"NGramWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"NGramWorker does not support beam search.")

View File

@@ -1,11 +1,12 @@
from functools import cached_property
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
@@ -189,69 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
assert seq_group_metadata_list is not None, (
assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding "
"requires non-None seq_group_metadata_list")
#logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
# num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
if execute_model_req.num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req)
return self._run_speculative_decoding_step(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
k=num_lookahead_slots,
)
return self._run_speculative_decoding_step(execute_model_req)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
"""
#logger.info("run proposer worker no spec")
self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
self.proposer_worker.execute_model(execute_model_req)
#logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
@@ -264,13 +233,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
@@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
#logger.info("get spec proposals")
# Generate proposals using draft worker.
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
#logger.info("score proposals")
proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
k,
execute_model_req,
proposals,
)
#logger.info("verify proposals")
accepted_token_ids, target_logprobs = self._verify_tokens(
seq_group_metadata_list, proposal_scores, proposals, k)
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
#logger.info("create output list")
return self._create_output_sampler_list(
seq_group_metadata_list,
execute_model_req.seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=k)
k=execute_model_req.num_lookahead_slots)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
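
SpecDecodeWorker.execute_model now takes only the request and branches on it: zero lookahead slots (or an empty batch) means the plain prefill path through both workers, anything else means one speculative decoding step, with k recovered from the same request downstream. A condensed sketch of just that dispatch (control flow only; the helper methods are the ones shown in the diff):

from typing import List

from vllm.sequence import ExecuteModelRequest, SamplerOutput


def dispatch_execute_model(worker, execute_model_req: ExecuteModelRequest
                           ) -> List[SamplerOutput]:
    """Condensed control flow of SpecDecodeWorker.execute_model."""
    assert execute_model_req.seq_group_metadata_list is not None, (
        "speculative decoding requires non-None seq_group_metadata_list")

    # Prefill / no-speculation path: run proposer and scorer on the same
    # request so both KV caches stay consistent.
    if (execute_model_req.num_lookahead_slots == 0
            or len(execute_model_req.seq_group_metadata_list) == 0):
        return worker._run_no_spec(execute_model_req)

    # Decode path: propose, score, verify; k now lives on the request.
    return worker._run_speculative_decoding_step(execute_model_req)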

View File

@@ -1,8 +1,9 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
@@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
proposal_len = execute_model_req.num_lookahead_slots
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
@@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
maybe_sampler_output, transposed = self._worker.sampler_output(
nonzero_execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=proposal_len,
)
maybe_sampler_output, transposed = self._worker.sampler_output(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len,
)
else:
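
The Top1Proposer hunk above is the other side of the same change: the proposer no longer receives block maps at all. For the sequences that can still speculate it builds a fresh ExecuteModelRequest holding only that sub-batch and the proposal length, then hands it to the draft worker's sampler_output. A hedged sketch of that narrowing step (variable names mirror the diff; this is not the full get_proposals body):

from vllm.sequence import ExecuteModelRequest


def propose_for_nonzero(worker, execute_model_req, nonzero_proposal_len_seqs):
    """Build a sub-batch request for the sequences that still fit a proposal."""
    proposal_len = execute_model_req.num_lookahead_slots
    nonzero_execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=nonzero_proposal_len_seqs,
        num_lookahead_slots=proposal_len,
    )
    # The draft worker samples proposal_len tokens for this sub-batch only.
    return worker.sampler_output(
        execute_model_req=nonzero_execute_model_req,
        sample_len=proposal_len,
    )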