[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)
This commit is contained in:
@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
|
||||
def set_include_gpu_probs_tensor(self):
    """Intentionally a no-op for this worker.

    SpecDecodeWorker invokes this hook on every proposer worker; this
    implementation has nothing to configure, so it simply returns.
    """
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self):
    """Intentionally a no-op for this worker.

    SpecDecodeWorker invokes this hook on every proposer worker; this
    implementation has nothing to configure, so it simply returns.
    """
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
|
||||
@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
# Need include_gpu_probs_tensor for MultiStepWorker
|
||||
self.model_runner.model.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
    """Turn on the sampler's ``should_modify_greedy_probs_inplace`` flag.

    Binds the sampler to a local first so the attribute chain is walked
    only once; behavior is identical to assigning through the full path.
    """
    sampler = self.model_runner.model.sampler
    sampler.should_modify_greedy_probs_inplace = True
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
|
||||
@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
|
||||
"""Implementation optional"""
|
||||
pass
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
    """Optional hook; the default implementation does nothing.

    Subclasses whose sampler supports in-place greedy-prob modification
    override this to flip the corresponding flag.
    """
|
||||
|
||||
|
||||
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
|
||||
"""Proposer worker which does not use a model with kvcache"""
|
||||
|
||||
@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self._worker.set_include_gpu_probs_tensor()
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
    """Forward the flag toggle to the wrapped worker.

    Does nothing when this wrapper is a dummy (i.e. ``self._is_dummy``
    is truthy), mirroring the guard used by the sibling methods.
    """
    if not self._is_dummy:
        self._worker.set_should_modify_greedy_probs_inplace()
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
"""
|
||||
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
|
||||
) = True
|
||||
(self.scorer_worker.model_runner.model.sampler.
|
||||
should_modify_greedy_probs_inplace) = True
|
||||
self.proposer_worker.set_include_gpu_probs_tensor()
|
||||
self.proposer_worker.set_should_modify_greedy_probs_inplace()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of cache blocks to use.
|
||||
|
||||
Reference in New Issue
Block a user