[Misc][Refactor] Introduce ExecuteModelData (#4540)
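Note (not part of the diff): every spec-decode entry point below used to take the batch as loose arguments (seq_group_metadata_list, the three blocks_to_* dicts, plus a lookahead count); this commit bundles them into a single ExecuteModelRequest that is threaded through proposer, scorer, and workers. The sketch below reconstructs the new dataclass from how the hunks use it; the defaults and the exact upstream definition in vllm/sequence.py are assumptions, not a copy of the real code.

from dataclasses import dataclass, field
from typing import Dict, List

from vllm.sequence import SequenceGroupMetadata


@dataclass
class ExecuteModelRequest:
    """One object carrying everything a worker step needs (sketch)."""
    # The batch of sequence groups to run.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # KV-cache block operations to perform before the forward pass.
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    # Speculative lookahead (the old loose `k` / `num_lookahead_slots` arg).
    num_lookahead_slots: int = 0

    def clone(self, seq_group_metadata_list) -> "ExecuteModelRequest":
        """Return a copy that shares block ops but swaps the batch."""
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in,
            blocks_to_swap_out=self.blocks_to_swap_out,
            blocks_to_copy=self.blocks_to_copy,
            num_lookahead_slots=self.num_lookahead_slots,
        )

The clone() helper is what lets callers re-scope a request to a different batch (the expanded batch in batch_expansion.py, the copied batch in multi_step_worker.py) while keeping the cache operations and lookahead setting intact.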
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -1,9 +1,10 @@
 from itertools import chain, count
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Iterator, List, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         no speculation is produced for that sequence.

         Args:
-            seq_group_metadata_list: The input sequence group metadata.
-            blocks_to_swap_in: This is passed to the worker during scoring.
-            blocks_to_swap_out: This is passed to the worker during scoring.
-            blocks_to_copy: This is passed to the worker during scoring.
-            k: The fixed proposal length.
+            execute_model_req: The execution request.
             proposals: The speculative proposals to score.
         Returns:
             SpeculativeScores: The scores of each speculative token, along with
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

         (spec_indices, non_spec_indices, target_seq_group_metadata_list,
          num_scoring_tokens) = self._expand_batch(
-             seq_group_metadata_list=seq_group_metadata_list,
+             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
              proposal_token_ids_list=proposal_token_ids_list_without_skips,
              proposal_lens_list=proposal_lens_list,
          )

         target_sampler_output = self._scorer_worker.execute_model(
-            seq_group_metadata_list=target_seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+            execute_model_req=execute_model_req.clone(
+                seq_group_metadata_list=target_seq_group_metadata_list, ))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]

         all_tokens, all_probs, spec_logprobs = self._contract_batch(
-            contracted_bs=len(seq_group_metadata_list),
+            contracted_bs=len(execute_model_req.seq_group_metadata_list),
             target_sampler_output=target_sampler_output,
             proposals=proposals,
             num_scoring_tokens=num_scoring_tokens,
             non_spec_indices=non_spec_indices,
             spec_indices=spec_indices,
-            k=k,
+            k=execute_model_req.num_lookahead_slots,
         )

         return SpeculativeScores(
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -1,10 +1,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional

 import torch

-from vllm.sequence import SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest


 @dataclass
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
     @abstractmethod
     def get_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         raise NotImplementedError

@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
     @abstractmethod
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         raise NotImplementedError

--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -1,9 +1,10 @@
 import copy
-from typing import Dict, List, Tuple
+from typing import List, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
     @torch.inference_mode()
     def sampler_output(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):

         For multi step worker, this indicator shall be True.
         """
-        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
-                                   blocks_to_swap_out, blocks_to_copy)
+        self._raise_if_unsupported(execute_model_req)

         # Shallow copy input data so modifications (such as appending tokens)
         # do not cause side-effects.
         copied_seq_group_metadata_list = self._shallow_copy_inputs(
-            seq_group_metadata_list)
+            execute_model_req.seq_group_metadata_list)
+        copied_execute_model_req = execute_model_req.clone(
+            copied_seq_group_metadata_list)

         # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(seq_group_metadata_list, sample_len)
+        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
+                                     sample_len)

         # Run model sample_len times.
         model_outputs = []
         for _ in range(sample_len):
             model_output = super().execute_model(
-                seq_group_metadata_list=copied_seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
+                execute_model_req=copied_execute_model_req)
             assert (len(model_output) == 1
                     ), "composing multistep workers not supported"
             model_output = model_output[0]
@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):

     def get_spec_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            max_proposal_len,
-        )
+        return self._proposer.get_proposals(execute_model_req)

     def _append_new_tokens(
             self, model_output: SamplerOutput,
@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):

     def _raise_if_unsupported(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
     ) -> None:
         """MultiStepWorker does not yet implement support for cache swap
         operations or beam search.
         """
-        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
             raise NotImplementedError(
                 "MultiStepWorker does not support cache operations")

         if any(
                 len(seq_group_metadata.seq_data.keys()) != 1
-                for seq_group_metadata in seq_group_metadata_list):
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
             raise NotImplementedError(
                 "MultiStepWorker does not support beam search.")
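Aside on the shallow-copy-then-clone pattern in sampler_output above: the multi-step worker appends draft tokens to its working metadata on every iteration of the sample_len loop, so it must not alias the caller's request. A sketch of what a _shallow_copy_inputs-style helper has to guarantee; the body is illustrative rather than the exact upstream implementation, and only the seq_data field name is taken from the diff:

import copy
from typing import List

from vllm.sequence import SequenceGroupMetadata


def shallow_copy_inputs(
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[SequenceGroupMetadata]:
    copied = []
    for seq_group_metadata in seq_group_metadata_list:
        # Copy the metadata object itself ...
        seq_group_metadata = copy.copy(seq_group_metadata)
        # ... and give it a fresh seq_data dict with copied entries, since
        # appending draft tokens mutates those entries in place.
        seq_group_metadata.seq_data = {
            seq_id: copy.copy(seq_data)
            for seq_id, seq_data in seq_group_metadata.seq_data.items()
        }
        copied.append(seq_group_metadata)
    return copied

A real implementation must also un-share any mutable token lists inside the copied SequenceData objects, since copy.copy alone still shares them.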
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -1,8 +1,8 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         # NGram don't need gpu sampler
         pass

-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-    ) -> None:
+    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
         """NGram doesn't depend on model execution, just pass this function"""
         pass

@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):

     def sampler_output(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         """NGram match algo to pick proposal candidate. Returns the list of
@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         For ngram worker, we already done needed transposed internal, so the
         indicator pass to sampler_output_to_torch shall be False.
         """
-        self._raise_if_unsupported(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-        )
+        self._raise_if_unsupported(execute_model_req)

         arr = []
         has_spec_out = False
-        for seq_group_metadata in seq_group_metadata_list:
+        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
             seq_data = next(iter(seq_group_metadata.seq_data.values()))

             input_ids = torch.as_tensor(seq_data.get_token_ids(),
@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         indices = token_ids.unsqueeze(2)

         token_probs = torch.zeros(
-            (len(seq_group_metadata_list), sample_len, self.vocab_size),
+            (len(execute_model_req.seq_group_metadata_list), sample_len,
+             self.vocab_size),
             dtype=torch.float32,
             device=self.device,
         )
         token_probs.scatter_(2, indices, 1)
         token_logprobs = torch.zeros(
-            (len(seq_group_metadata_list), sample_len, self.vocab_size),
+            (len(execute_model_req.seq_group_metadata_list), sample_len,
+             self.vocab_size),
             dtype=torch.float32,
             device=self.device,
         )
-        for i in range(len(seq_group_metadata_list)):
+        for i in range(len(execute_model_req.seq_group_metadata_list)):
             outputs.append(
                 SamplerOutput(
                     outputs=None,
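Aside on the scatter_ pattern in the hunk above: the ngram worker proposes tokens deterministically, so it fabricates one-hot "distributions" by scattering 1.0 into zero tensors at the proposed token ids. A standalone illustration with reduced shapes (variable names here are not vLLM's):

import torch

batch_size, sample_len, vocab_size = 2, 3, 8
# Proposed token ids for each (sequence, step): shape [batch, sample_len].
token_ids = torch.tensor([[1, 4, 2],
                          [7, 0, 3]])

probs = torch.zeros(batch_size, sample_len, vocab_size)
# scatter_ writes 1.0 at the dim=2 positions given by the ids -> one-hot rows.
probs.scatter_(2, token_ids.unsqueeze(2), 1)

assert probs[0, 1, 4] == 1.0 and probs.sum() == batch_size * sample_len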
@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):

     def get_spec_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            max_proposal_len,
-        )
+        return self._proposer.get_proposals(execute_model_req)

     def _raise_if_unsupported(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
     ) -> None:
         """NGramWorker does not yet implement support for cache swap
         operations or beam search.
         """
-        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
             raise NotImplementedError(
                 "NGramWorker does not support cache operations")

         if any(
                 len(seq_group_metadata.seq_data.keys()) != 1
-                for seq_group_metadata in seq_group_metadata_list):
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
             raise NotImplementedError(
                 "NGramWorker does not support beam search.")
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -1,11 +1,12 @@
 from functools import cached_property
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
@@ -189,69 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

     @torch.inference_mode()
     def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """

-        assert seq_group_metadata_list is not None, (
+        assert execute_model_req.seq_group_metadata_list is not None, (
             "speculative decoding "
             "requires non-None seq_group_metadata_list")

-        #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
-        #            num_lookahead_slots)
-
         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.
-        if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
-            return self._run_no_spec(
-                seq_group_metadata_list=seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
+        if execute_model_req.num_lookahead_slots == 0 or len(
+                execute_model_req.seq_group_metadata_list) == 0:
+            return self._run_no_spec(execute_model_req)

-        return self._run_speculative_decoding_step(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            k=num_lookahead_slots,
-        )
+        return self._run_speculative_decoding_step(execute_model_req)

     @nvtx_range("spec_decode_worker._run_no_spec")
     def _run_no_spec(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-    ) -> List[SamplerOutput]:
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Run a prefill step, without any speculation. The input is sent to the
         proposer and scorer model so that the KV cache is consistent between the
         two.
         """
-        #logger.info("run proposer worker no spec")

-        self.proposer_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+        self.proposer_worker.execute_model(execute_model_req)

-        #logger.info("run target worker no spec")
-        sampler_output = self.scorer_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+        sampler_output = self.scorer_worker.execute_model(execute_model_req)
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]
@@ -264,13 +233,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

     @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
     def _run_speculative_decoding_step(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
-    ) -> List[SamplerOutput]:
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Execute a single step of speculative decoding.

         This invokes the proposer worker to get k speculative tokens for each
@@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         #logger.info("get spec proposals")
         # Generate proposals using draft worker.
-        assert blocks_to_swap_in is not None
-        assert blocks_to_swap_out is not None
-        assert blocks_to_copy is not None
-        proposals = self.proposer_worker.get_spec_proposals(
-            seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-            blocks_to_copy, k)
+        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

         #logger.info("score proposals")
         proposal_scores = self.scorer.score_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            k,
+            execute_model_req,
             proposals,
         )

         #logger.info("verify proposals")
         accepted_token_ids, target_logprobs = self._verify_tokens(
-            seq_group_metadata_list, proposal_scores, proposals, k)
+            execute_model_req.seq_group_metadata_list, proposal_scores,
+            proposals, execute_model_req.num_lookahead_slots)

         #logger.info("create output list")
         return self._create_output_sampler_list(
-            seq_group_metadata_list,
+            execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=k)
+            k=execute_model_req.num_lookahead_slots)

     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -1,8 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.util import sampler_output_to_torch
@@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):

     def get_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         """Get speculative proposals given the input batch.

         Sequences which would exceed the max model length are skipped during
         speculation.
         """
+        proposal_len = execute_model_req.num_lookahead_slots
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

         # Split speculative- and non-speculative- sequences.
         (
@@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
             # token_ids is like [batch] format in proposal_len size list,
             # while if it is false, the format would be [proposal_len]
             # in batch size list
-            maybe_sampler_output, transposed = self._worker.sampler_output(
+            nonzero_execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=nonzero_proposal_len_seqs,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
+                num_lookahead_slots=proposal_len,
+            )
+            maybe_sampler_output, transposed = self._worker.sampler_output(
+                execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
             )
         else:
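End to end, a call site after this refactor could look like the hypothetical snippet below; the scheduler-produced metadata list and an already-constructed SpecDecodeWorker are assumed, and the defaults come from the dataclass sketch at the top of this page:

from vllm.sequence import ExecuteModelRequest

# seq_group_metadata_list: assumed to come from the scheduler, as elsewhere
# in vLLM; spec_decode_worker is an already-constructed SpecDecodeWorker.
req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
    num_lookahead_slots=4,  # the old loose `k` / `num_lookahead_slots` arg
)
sampler_outputs = spec_decode_worker.execute_model(req)

# Narrowing the batch re-uses the same object instead of re-plumbing args:
first_only = req.clone(seq_group_metadata_list=req.seq_group_metadata_list[:1])
assert first_only.num_lookahead_slots == req.num_lookahead_slots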