[BUGFIX] Raise an error for no draft token case when draft_tp>1 (#6369)
This commit is contained in:
@@ -22,6 +22,9 @@ class SpeculativeProposals:
|
||||
# The valid length of each proposal; can be zero.
|
||||
proposal_lens: torch.Tensor
|
||||
|
||||
# A flag to mark that there's no available proposals
|
||||
no_proposals: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
return (f"SpeculativeProposals("
|
||||
f"proposal_token_ids={self.proposal_token_ids}, "
|
||||
|
||||
@@ -109,6 +109,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
allow_zero_draft_token_step = True
|
||||
ngram_prompt_lookup_max = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
@@ -133,6 +134,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
if draft_tp == 1:
|
||||
draft_worker_kwargs[
|
||||
"model_runner_cls"] = TP1DraftModelRunner
|
||||
else:
|
||||
allow_zero_draft_token_step = False
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
|
||||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
|
||||
@@ -155,10 +158,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
logger.info("Configuring SpecDecodeWorker with sampler=%s",
|
||||
type(spec_decode_sampler))
|
||||
|
||||
return SpecDecodeWorker(proposer_worker,
|
||||
scorer_worker,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
spec_decode_sampler=spec_decode_sampler)
|
||||
return SpecDecodeWorker(
|
||||
proposer_worker,
|
||||
scorer_worker,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
allow_zero_draft_token_step=allow_zero_draft_token_step)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -167,6 +172,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||
disable_by_batch_size: Optional[int] = None,
|
||||
allow_zero_draft_token_step: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Create a SpecDecodeWorker.
|
||||
@@ -187,11 +193,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
disable speculative decoding for new incoming requests.
|
||||
metrics_collector: Helper class for collecting metrics; can be set
|
||||
for testing purposes.
|
||||
allow_zero_draft_token_step: whether to allow a step where the draft
|
||||
model generates no draft token; should disallow when the tp of
|
||||
draft model is larger than 1 (TODO: #5814)
|
||||
"""
|
||||
self.proposer_worker = proposer_worker
|
||||
self.scorer_worker = scorer_worker
|
||||
self.disable_by_batch_size = disable_by_batch_size or float("inf")
|
||||
self.spec_decode_sampler = spec_decode_sampler
|
||||
self._allow_zero_draft_token_step = allow_zero_draft_token_step
|
||||
self._metrics = AsyncMetricsCollector(
|
||||
self.spec_decode_sampler
|
||||
) if metrics_collector is None else metrics_collector
|
||||
@@ -461,6 +471,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposals = self.proposer_worker.get_spec_proposals(
|
||||
execute_model_req, self._seq_with_bonus_token_in_last_step)
|
||||
|
||||
if not self._allow_zero_draft_token_step and proposals.no_proposals:
|
||||
#TODO: Fix it #5814
|
||||
raise RuntimeError("Cannot handle cases where distributed draft "
|
||||
"workers generate no tokens")
|
||||
|
||||
proposal_scores = self.scorer.score_proposals(
|
||||
execute_model_req,
|
||||
proposals,
|
||||
|
||||
@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
|
||||
proposal_token_ids=proposal_tokens,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens,
|
||||
)
|
||||
no_proposals=maybe_sampler_output is None)
|
||||
|
||||
return proposals
|
||||
|
||||
|
||||
Reference in New Issue
Block a user