[Bugfix] Make spec. decode respect per-request seed. (#6034)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
SequenceGroupMetadata, SequenceGroupState,
|
||||
get_all_seq_ids)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
|
||||
@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
for data in new_seq_data_dict.values():
|
||||
data.update_num_computed_tokens(data.get_len() - 1)
|
||||
|
||||
if (seq_group_metadata.state is not None
|
||||
and seq_group_metadata.state.generator is not None):
|
||||
generator = torch.Generator(
|
||||
device=seq_group_metadata.state.generator.device)
|
||||
generator.set_state(seq_group_metadata.state.generator.get_state())
|
||||
state = SequenceGroupState(generator=generator)
|
||||
else:
|
||||
state = None
|
||||
|
||||
return SequenceGroupMetadata(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
is_prompt=seq_group_metadata.is_prompt,
|
||||
@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
},
|
||||
lora_request=None,
|
||||
token_chunk_size=1,
|
||||
state=state,
|
||||
)
|
||||
|
||||
def _split_scoring_output(
|
||||
|
||||
@@ -9,7 +9,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
@@ -521,11 +521,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# Get proposed tokens.
|
||||
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
|
||||
|
||||
# Sampler arguments
|
||||
sampler_extra_kwargs = {}
|
||||
if isinstance(self.spec_decode_sampler,
|
||||
SpecDecodeStochasticBaseSampler):
|
||||
|
||||
# Get sequence group state
|
||||
generators = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
if (seq_group_metadata.state is not None
|
||||
and seq_group_metadata.state.generator is not None):
|
||||
generators.append(seq_group_metadata.state.generator)
|
||||
else:
|
||||
generators.append(None)
|
||||
|
||||
sampler_extra_kwargs["generators"] = generators
|
||||
|
||||
accepted_token_ids = self.spec_decode_sampler(
|
||||
target_probs=proposal_verifier_probs,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
draft_probs=proposal_probs,
|
||||
draft_token_ids=proposal_token_ids,
|
||||
**sampler_extra_kwargs,
|
||||
)
|
||||
|
||||
# Append output tokens from non-speculative sequences to
|
||||
|
||||
Reference in New Issue
Block a user