[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)

Author: William Lin
Date: 2024-08-08 22:42:45 -07:00
Committed by: GitHub
Parent: 99b4cf5f23
Commit: 57b7be0e1c
8 changed files with 52 additions and 3 deletions


@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
     def set_include_gpu_probs_tensor(self):
         pass
 
+    def set_should_modify_greedy_probs_inplace(self):
+        pass
+
     @torch.inference_mode()
     def sampler_output(
         self,
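
MedusaWorker proposes draft tokens with its own heads rather than through a
full model's sampler, so a no-op satisfies the new hook. For context, here is
a minimal sketch of what the flag ultimately controls, assuming (as the flag
name suggests) that greedy rows of the probability tensor are rewritten in
place to one-hot distributions, so rejection sampling can compare draft and
target distributions even when sampling is deterministic. The function below
is illustrative, not vLLM's API:

    import torch

    def modify_greedy_probs_inplace(probs: torch.Tensor,
                                    sampled_token_ids: torch.Tensor) -> None:
        # Rewrite each row of `probs` in place to a one-hot distribution
        # on the greedily sampled (argmax) token.
        probs.fill_(0.0)
        probs.scatter_(-1, sampled_token_ids.unsqueeze(-1), 1.0)

    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    greedy_ids = probs.argmax(dim=-1)
    modify_greedy_probs_inplace(probs, greedy_ids)
    # Each row of `probs` is now exactly one-hot on its argmax token.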


@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+
     @torch.inference_mode()
     def sampler_output(
         self,
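
The new setter mirrors the existing set_include_gpu_probs_tensor hook: the
worker mutates its own sampler instead of SpecDecodeWorker reaching through
model_runner.model.sampler from the outside. A rough sketch of the effect,
using a toy sampler (TinySampler and greedy_sample are illustrative
stand-ins, not vLLM classes):

    import torch

    class TinySampler:
        def __init__(self) -> None:
            self.include_gpu_probs_tensor = False
            self.should_modify_greedy_probs_inplace = False

        def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
            probs = torch.softmax(logits, dim=-1)
            token_ids = probs.argmax(dim=-1)
            if self.should_modify_greedy_probs_inplace:
                # One-hot rewrite, so downstream rejection sampling sees a
                # well-defined draft distribution for greedy requests.
                probs.zero_()
                probs.scatter_(-1, token_ids.unsqueeze(-1), 1.0)
            return token_ids

    sampler = TinySampler()
    sampler.should_modify_greedy_probs_inplace = True  # what the setter flips
    _ = sampler.greedy_sample(torch.randn(4, 16))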


@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         """Implementation optional"""
         pass
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+
 
 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""


@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         # Need include_gpu_probs_tensor for multi_step_worker
         self._worker.set_include_gpu_probs_tensor()
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
     def load_model(self) -> None:
         if self._is_dummy:
             return
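
SmallerTpProposerWorker follows its usual guard-then-delegate pattern: ranks
that are not part of the draft model's smaller tensor-parallel group hold a
dummy worker and must ignore the call, exactly as load_model does below. A
condensed sketch of the pattern (the wrapper class is illustrative):

    class TpWrapperSketch:
        def __init__(self, worker, is_dummy: bool) -> None:
            self._worker = worker
            self._is_dummy = is_dummy

        def set_should_modify_greedy_probs_inplace(self) -> None:
            # Ranks outside the draft TP group short-circuit; the rest
            # delegate to the wrapped worker.
            if self._is_dummy:
                return
            self._worker.set_should_modify_greedy_probs_inplace()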


@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
+        (self.scorer_worker.model_runner.model.sampler.
+         should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
+        self.proposer_worker.set_should_modify_greedy_probs_inplace()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.