[BugFix] Fix use of per-request seed with pipeline parallel (#6698)
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
 import torch
 
+from vllm import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, SequenceGroupState,
-                           get_all_seq_ids)
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -16,6 +16,8 @@ SeqId = int
 TargetSeqId = int
 TokenId = int
 
+DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
+
 
 class BatchExpansionTop1Scorer(SpeculativeScorer):
     """Implements a speculative scorer that uses batch expansion to get
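
The new DEFAULT_SIMPLE_SAMPLING_PARAMS constant gives the scorer a plain, unseeded parameter set to fall back on when the tokens it samples will be discarded anyway. A minimal sketch of the assumption it relies on, namely that a default-constructed SamplingParams carries no seed:

    from vllm import SamplingParams

    # A default SamplingParams has no seed, so sampling with it never touches
    # any per-request generator state.
    DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
    assert DEFAULT_SIMPLE_SAMPLING_PARAMS.seed is None
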
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             token_ids_to_score = self._get_token_ids_to_score(
                 proposal_token_ids[batch_index])
 
+            # Use simpler sampling parameters apart from for final token
+            # (in particular don't do seeded sampling) since those sampled tokens
+            # aren't used.
+            # We don't replace the sampling_params in the greedy case because
+            # this also controls whether the probs get modified in the sampler
+            # (see use of _modify_greedy_probs_inplace there).
+            sampling_params = input_seq_group_metadata.sampling_params
+            non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
+                if sampling_params.temperature else sampling_params
+
             target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-            for token_ids in token_ids_to_score:
+            last_index = len(token_ids_to_score) - 1
+            for i, token_ids in enumerate(token_ids_to_score):
+                target_sampling_params = sampling_params if i == last_index \
+                    else non_bonus_sampling_params
                 target_seq_group_metadata_list.append(
                     self._create_single_target_seq_group_metadata(
                         input_seq_group_metadata,
                         input_seq_id,
                         next(target_seq_ids_iter),
                         token_ids,
+                        sampling_params=target_sampling_params,
                     ))
 
         return target_seq_group_metadata_list
 
-    @staticmethod
     def _create_single_target_seq_group_metadata(
+        self,
         seq_group_metadata: SequenceGroupMetadata,
         seq_id: SeqId,
         target_seq_id: TargetSeqId,
         token_ids: List[TokenId],
+        sampling_params: SamplingParams,
     ) -> SequenceGroupMetadata:
         """Create a single target SequenceGroupMetadata.
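
In the expansion loop above, only the final position, the bonus token whose sample is actually kept, retains the request's own sampling parameters; every earlier position gets the simple unseeded defaults, and in the greedy (zero-temperature) case the original parameters are kept throughout so the sampler's greedy-probability handling still applies. A standalone sketch of that selection, with placeholder strings standing in for SamplingParams objects:

    def pick_params_per_position(num_positions, request_params, simple_params):
        # Only the last expansion (the bonus token) keeps the request's params;
        # all earlier expansions use the simple, unseeded defaults.
        last_index = num_positions - 1
        return [
            request_params if i == last_index else simple_params
            for i in range(num_positions)
        ]

    assert pick_params_per_position(4, "seeded", "simple") == [
        "simple", "simple", "simple", "seeded"
    ]
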
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)
 
-        if (seq_group_metadata.state is not None
-                and seq_group_metadata.state.generator is not None):
-            generator = torch.Generator(
-                device=seq_group_metadata.state.generator.device)
-            generator.set_state(seq_group_metadata.state.generator.get_state())
-            state = SequenceGroupState(generator=generator)
-        else:
-            state = None
-
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
             seq_data=new_seq_data_dict,
-            sampling_params=seq_group_metadata.sampling_params,
+            sampling_params=sampling_params,
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
             lora_request=None,
             token_chunk_size=1,
-            state=state,
         )
 
     def _split_scoring_output(
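
The deleted block used to clone the sequence group's generator into a fresh SequenceGroupState for every expanded target sequence; with the generators now held by the model runner and passed to the sampler directly (see the worker changes below), that copy is no longer needed and the sampling_params override is threaded through instead. For reference, a small standalone sketch of the state cloning the removed code performed, using only plain torch:

    import torch

    # Clone one generator's RNG state into another, as the removed
    # SequenceGroupState path did for each expanded target sequence.
    source = torch.Generator(device="cpu")
    source.manual_seed(1234)

    clone = torch.Generator(device=source.device)
    clone.set_state(source.get_state())

    # Both generators now produce identical random streams.
    assert torch.equal(torch.rand(4, generator=source),
                       torch.rand(4, generator=clone))
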
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         seq_lens, query_lens = self._prepare_input_tensors(
             seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             previous_hidden_states=execute_model_req.previous_hidden_states.
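
The Medusa worker now asks its model runner for the per-request generators (dropping any that belong to finished requests) and passes them into SamplingMetadata.prepare, instead of relying on generator state carried inside SequenceGroupMetadata. A rough sketch of the kind of registry this implies; the class and method shapes here are illustrative assumptions, not the actual vLLM model-runner API:

    from typing import Dict, List, Optional

    import torch

    class GeneratorRegistry:
        """Holds one torch.Generator per seeded request, keyed by request id."""

        def __init__(self) -> None:
            self._generators: Dict[str, torch.Generator] = {}

        def register(self, request_id: str, seed: int,
                     device: str = "cpu") -> torch.Generator:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)
            self._generators[request_id] = generator
            return generator

        def get_generators(
            self,
            finished_request_ids: Optional[List[str]] = None,
        ) -> Dict[str, torch.Generator]:
            # Drop entries for finished requests, then return the live mapping.
            for request_id in finished_request_ids or []:
                self._generators.pop(request_id, None)
            return self._generators

    registry = GeneratorRegistry()
    registry.register("req-0", seed=42)
    assert "req-0" in registry.get_generators()
    assert not registry.get_generators(finished_request_ids=["req-0"])
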
@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         (input_tokens, seq_lens,
          query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             input_ids=input_tokens,
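
The MLPSpeculator worker receives the same treatment as the Medusa worker above. From the user's point of view, the behaviour all of this protects is per-request seeding; a hedged end-to-end usage sketch follows (the model name is just a placeholder, and the determinism claim is the stated intent of the seed parameter rather than something this snippet proves):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model

    params = SamplingParams(temperature=0.8, seed=1234, max_tokens=16)
    first = llm.generate(["The capital of France is"], params)
    second = llm.generate(["The capital of France is"], params)

    # With a fixed per-request seed, repeated runs are expected to match.
    assert first[0].outputs[0].text == second[0].outputs[0].text
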
@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
     Current NGramWorker only implements prompt lookup decoding,
@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
+        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
+        self.generators = scorer_runner.get_generators(
+        ) if scorer_runner else None
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.spec_decode_sampler = spec_decode_sampler
         self._allow_zero_draft_token_step = allow_zero_draft_token_step
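
The scorer's model_runner is looked up defensively since not every scorer worker necessarily exposes one; when it does, the spec-decode worker keeps a reference to the runner's live generator mapping, so generators registered later are visible at verification time. A tiny sketch of that sharing with stand-in classes (the names are illustrative, not vLLM types):

    class _Runner:
        def __init__(self):
            self._generators = {}

        def get_generators(self):
            return self._generators

    class _Scorer:
        def __init__(self, runner):
            self.model_runner = runner

    runner = _Runner()
    scorer_worker = _Scorer(runner)

    scorer_runner = getattr(scorer_worker, "model_runner", None)
    generators = scorer_runner.get_generators() if scorer_runner else None

    # The same dict object is shared, so later registrations show up here too.
    runner.get_generators()["req-0"] = "generator"
    assert generators is not None and "req-0" in generators
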
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]
 
         # Sampler arguments
-        sampler_extra_kwargs = {}
-        if isinstance(self.spec_decode_sampler,
-                      SpecDecodeStochasticBaseSampler):
-
-            # Get sequence group state
-            generators = []
-            for seq_group_metadata in seq_group_metadata_list:
-                if (seq_group_metadata.state is not None
-                        and seq_group_metadata.state.generator is not None):
-                    generators.append(seq_group_metadata.state.generator)
-                else:
-                    generators.append(None)
-
-            sampler_extra_kwargs["generators"] = generators
+        sampler_extra_kwargs: Dict[str, Any] = {}
+        if self.generators and isinstance(self.spec_decode_sampler,
+                                          SpecDecodeStochasticBaseSampler):
+            sampler_extra_kwargs["seeded_seqs"] = {
+                idx: self.generators[sgm.request_id]
+                for idx, sgm in enumerate(seq_group_metadata_list)
+                if sgm.sampling_params.seed is not None
+            }
 
         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
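
Verification now hands the stochastic acceptance sampler a seeded_seqs mapping from batch index to the owning request's generator, built only for sequences whose sampling params actually set a seed; the old walk over SequenceGroupState is gone. A self-contained sketch of building and consuming such a mapping; the draw_uniform helper below is illustrative and not the vLLM sampler:

    from typing import Dict, List, Optional

    import torch

    def build_seeded_seqs(request_ids: List[str], seeds: List[Optional[int]],
                          generators: Dict[str, torch.Generator],
                          ) -> Dict[int, torch.Generator]:
        # Batch position -> generator, only for requests that set a seed.
        return {
            idx: generators[request_id]
            for idx, (request_id, seed) in enumerate(zip(request_ids, seeds))
            if seed is not None
        }

    def draw_uniform(num_draft_tokens: int, batch_size: int,
                     seeded_seqs: Dict[int, torch.Generator]) -> torch.Tensor:
        # Seeded rows draw from their own generator; other rows fall back to
        # the global RNG (generator=None).
        rows = [
            torch.rand(num_draft_tokens, generator=seeded_seqs.get(idx))
            for idx in range(batch_size)
        ]
        return torch.stack(rows)

    generators = {"req-a": torch.Generator().manual_seed(7)}
    seeded_seqs = build_seeded_seqs(["req-a", "req-b"], [7, None], generators)
    assert list(seeded_seqs) == [0]
    print(draw_uniform(4, 2, seeded_seqs))
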