[BugFix] Fix use of per-request seed with pipeline parallel (#6698)

Nick Hill
2024-07-30 10:40:08 -07:00
committed by GitHub
parent f058403683
commit 5cf9254a9c
21 changed files with 222 additions and 137 deletions
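The hunks below point at the root cause: per-request torch.Generator objects for seeded sampling used to live on SequenceGroupState, inside the SequenceGroupMetadata that is re-sent between workers, so with pipeline parallelism the generator state did not follow the request reliably. The fix keeps the generators in the model runner, keyed by request id, and passes only request ids around. A minimal sketch of that ownership pattern (the class and ensure() are illustrative, not vLLM API; the accessor mirrors the model_runner.get_generators(finished_requests_ids) calls in the diff):

from typing import Dict, List, Optional

import torch


class GeneratorCache:
    """Sketch: per-request generator cache owned by the model runner.

    Because the cache lives in one place and only request ids travel in
    metadata, every stage that samples for a request sees the same
    random stream.
    """

    def __init__(self) -> None:
        self._generators: Dict[str, torch.Generator] = {}

    def get_generators(
        self, finished_request_ids: Optional[List[str]] = None
    ) -> Dict[str, torch.Generator]:
        # Prune entries for completed requests, then hand out the cache.
        if finished_request_ids:
            for request_id in finished_request_ids:
                self._generators.pop(request_id, None)
        return self._generators

    def ensure(self, request_id: str, seed: int) -> torch.Generator:
        # Hypothetical helper: create the generator once per request and
        # reuse it so later steps continue the same stream.
        if request_id not in self._generators:
            self._generators[request_id] = torch.Generator().manual_seed(seed)
        return self._generators[request_id]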

vllm/spec_decode/batch_expansion.py

@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
 import torch
 
+from vllm import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, SequenceGroupState,
-                           get_all_seq_ids)
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
                                    split_batch_by_proposal_len)
@@ -16,6 +16,8 @@ SeqId = int
 TargetSeqId = int
 TokenId = int
 
+DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
+
 
 class BatchExpansionTop1Scorer(SpeculativeScorer):
     """Implements a speculative scorer that uses batch expansion to get
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         token_ids_to_score = self._get_token_ids_to_score(
             proposal_token_ids[batch_index])
 
+        # Use simpler sampling parameters for all but the final token
+        # (in particular, don't do seeded sampling) since those sampled
+        # tokens aren't used.
+        # We don't replace the sampling_params in the greedy case because
+        # this also controls whether the probs get modified in the sampler
+        # (see use of _modify_greedy_probs_inplace there).
+        sampling_params = input_seq_group_metadata.sampling_params
+        non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
+            if sampling_params.temperature else sampling_params
+
         target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for token_ids in token_ids_to_score:
+        last_index = len(token_ids_to_score) - 1
+        for i, token_ids in enumerate(token_ids_to_score):
+            target_sampling_params = sampling_params if i == last_index \
+                else non_bonus_sampling_params
             target_seq_group_metadata_list.append(
                 self._create_single_target_seq_group_metadata(
                     input_seq_group_metadata,
                     input_seq_id,
                     next(target_seq_ids_iter),
                     token_ids,
+                    sampling_params=target_sampling_params,
                 ))
 
         return target_seq_group_metadata_list
 
-    @staticmethod
     def _create_single_target_seq_group_metadata(
+        self,
         seq_group_metadata: SequenceGroupMetadata,
         seq_id: SeqId,
         target_seq_id: TargetSeqId,
         token_ids: List[TokenId],
+        sampling_params: SamplingParams,
     ) -> SequenceGroupMetadata:
         """Create a single target SequenceGroupMetadata.
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)
 
-        if (seq_group_metadata.state is not None
-                and seq_group_metadata.state.generator is not None):
-            generator = torch.Generator(
-                device=seq_group_metadata.state.generator.device)
-            generator.set_state(seq_group_metadata.state.generator.get_state())
-            state = SequenceGroupState(generator=generator)
-        else:
-            state = None
-
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
             seq_data=new_seq_data_dict,
-            sampling_params=seq_group_metadata.sampling_params,
+            sampling_params=sampling_params,
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
             lora_request=None,
             token_chunk_size=1,
-            state=state,
         )
 
     def _split_scoring_output(
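Net effect of the batch-expansion changes: each proposal token still gets its own target sequence, but only the last (bonus) position keeps the request's real sampling parameters; earlier positions use DEFAULT_SIMPLE_SAMPLING_PARAMS so that seeded sampling does not burn extra draws from the request's generator on tokens that are discarded, and the per-target copying of generator state is gone entirely. The selection rule in isolation (score_params_for is a hypothetical helper, not from the commit):

from typing import List

from vllm import SamplingParams

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()


def score_params_for(real: SamplingParams,
                     num_positions: int) -> List[SamplingParams]:
    """Sketch: pick sampling params for each expanded position.

    Greedy requests (temperature == 0) keep their params everywhere,
    since the params also gate greedy prob modification in the sampler;
    stochastic requests use plain params for all but the final position,
    whose sampled token is the only one actually used.
    """
    non_bonus = DEFAULT_SIMPLE_SAMPLING_PARAMS if real.temperature else real
    return [
        real if i == num_positions - 1 else non_bonus
        for i in range(num_positions)
    ]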

vllm/spec_decode/medusa_worker.py

@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         seq_lens, query_lens = self._prepare_input_tensors(
             seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             previous_hidden_states=execute_model_req.previous_hidden_states.
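MedusaWorker now fetches the shared generator dict from its model runner (pruning finished requests in the same call) and threads it into SamplingMetadata.prepare. Inside prepare, the lookup plausibly reduces to a per-sequence-group rule like the following (a hedged sketch, not the actual vLLM code; generators is keyed by request id as above):

from typing import Dict, Optional

import torch


def generator_for_request(
        request_id: str,
        seed: Optional[int],
        generators: Optional[Dict[str, torch.Generator]],
) -> Optional[torch.Generator]:
    # Sketch: seeded requests get (or lazily create) their cached
    # generator; unseeded requests sample from the global RNG.
    if seed is None or generators is None:
        return None
    if request_id not in generators:
        generators[request_id] = torch.Generator().manual_seed(seed)
    return generators[request_id]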

vllm/spec_decode/mlp_speculator_worker.py

@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         (input_tokens, seq_lens,
          query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             input_ids=input_tokens,
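MLPSpeculatorWorker gets the identical treatment. The invariant this buys is easy to demonstrate: a dedicated generator makes sampling reproducible for that request no matter what else draws from the global RNG in between:

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])


def sample(seed: int) -> int:
    gen = torch.Generator().manual_seed(seed)
    # Unrelated draws on the global RNG must not perturb the result.
    torch.rand(1000)
    return int(torch.multinomial(probs, num_samples=1, generator=gen).item())


assert sample(42) == sample(42)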

vllm/spec_decode/ngram_worker.py

@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
     Current NGramWorker only implements prompt lookup decoding,
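The NGramWorker change is a tidy-up that falls out of the refactor: the explicit LoraNotSupportedWorkerBase base (and its import) can go, assuming NonLLMProposerWorkerBase already carries that behavior through its own ancestry in the proposer worker base hierarchy. Illustrated abstractly:

# Illustrative only: repeating an ancestor that is already on the MRO
# adds nothing and couples the subclass to an implementation detail.
# (Assumes ProposerWorkerBase derives from LoraNotSupportedWorkerBase.)
class LoraNotSupportedWorkerBase: ...
class ProposerWorkerBase(LoraNotSupportedWorkerBase): ...
class NonLLMProposerWorkerBase(ProposerWorkerBase): ...


class NGramWorker(NonLLMProposerWorkerBase):  # LoRA base comes along for free
    ...


assert issubclass(NGramWorker, LoraNotSupportedWorkerBase)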

vllm/spec_decode/spec_decode_worker.py

@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
+        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
+        self.generators = scorer_runner.get_generators(
+        ) if scorer_runner else None
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.spec_decode_sampler = spec_decode_sampler
         self._allow_zero_draft_token_step = allow_zero_draft_token_step
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]
 
         # Sampler arguments
-        sampler_extra_kwargs = {}
-        if isinstance(self.spec_decode_sampler,
-                      SpecDecodeStochasticBaseSampler):
-
-            # Get sequence group state
-            generators = []
-            for seq_group_metadata in seq_group_metadata_list:
-                if (seq_group_metadata.state is not None
-                        and seq_group_metadata.state.generator is not None):
-                    generators.append(seq_group_metadata.state.generator)
-                else:
-                    generators.append(None)
-
-            sampler_extra_kwargs["generators"] = generators
+        sampler_extra_kwargs: Dict[str, Any] = {}
+        if self.generators and isinstance(self.spec_decode_sampler,
+                                          SpecDecodeStochasticBaseSampler):
+            sampler_extra_kwargs["seeded_seqs"] = {
+                idx: self.generators[sgm.request_id]
+                for idx, sgm in enumerate(seq_group_metadata_list)
+                if sgm.sampling_params.seed is not None
+            }
 
         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
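On the scoring side, SpecDecodeWorker now builds seeded_seqs, a dict from batch row index to that request's generator, covering only rows whose sampling_params carry a seed, and passes it to the stochastic acceptance sampler in place of the old parallel list of per-row generators with None holes. A hedged sketch of how a sampler can consume such a dict (the real SpecDecodeStochasticBaseSampler does considerably more):

from typing import Dict

import torch


def sample_uniform_rows(num_rows: int, num_cols: int,
                        seeded_seqs: Dict[int, torch.Generator]
                        ) -> torch.Tensor:
    # Sketch: unseeded rows draw from the global RNG in one shot;
    # seeded rows are overwritten using their own generator, so each
    # request's random stream stays independent and reproducible.
    uniform = torch.rand(num_rows, num_cols)
    for row, generator in seeded_seqs.items():
        uniform[row] = torch.rand(num_cols, generator=generator)
    return uniform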