[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) (#6050)
Co-authored-by: Sirej Dua <sirej.dua@databricks.com>
Co-authored-by: Sirej Dua <Sirej Dua>
This commit is contained in:
@@ -113,24 +113,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||
|
||||
disable_bonus_tokens = True
|
||||
|
||||
if ngram_prompt_lookup_max > 0:
|
||||
disable_bonus_tokens = False
|
||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
||||
ngram_prompt_lookup_max)
|
||||
elif draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "mlp_speculator":
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
disable_bonus_tokens = False
|
||||
else:
|
||||
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
|
||||
'parallel_config']
|
||||
draft_tp = draft_parallel_config.tensor_parallel_size
|
||||
target_tp = scorer_worker.parallel_config.tensor_parallel_size
|
||||
|
||||
if draft_tp == 1:
|
||||
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
if draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "mlp_speculator":
|
||||
disable_bonus_tokens = False
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
else:
|
||||
if draft_tp == 1:
|
||||
draft_worker_kwargs[
|
||||
"model_runner_cls"] = TP1DraftModelRunner
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
|
||||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
|
||||
proposer_worker, draft_tp, target_tp)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user