[Bugfix] Make spec. decode respect per-request seed. (#6034)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Thomas Parnell
2024-07-19 04:22:08 +02:00
committed by GitHub
parent b5672a112c
commit d4201e06d5
8 changed files with 293 additions and 46 deletions

View File

@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
SequenceGroupMetadata, SequenceGroupState,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generator = torch.Generator(
device=seq_group_metadata.state.generator.device)
generator.set_state(seq_group_metadata.state.generator.get_state())
state = SequenceGroupState(generator=generator)
else:
state = None
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
},
lora_request=None,
token_chunk_size=1,
state=state,
)
def _split_scoring_output(

View File

@@ -9,7 +9,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
@@ -521,11 +521,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Sampler arguments
sampler_extra_kwargs = {}
if isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
generators = []
for seq_group_metadata in seq_group_metadata_list:
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generators.append(seq_group_metadata.state.generator)
else:
generators.append(None)
sampler_extra_kwargs["generators"] = generators
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
)
# Append output tokens from non-speculative sequences to