[Spec Decode] (1/2) Remove batch expansion (#8839)

This commit is contained in:
Lily Liu
2024-10-01 16:04:42 -07:00
committed by GitHub
parent 22f5851b80
commit 1570203864
29 changed files with 531 additions and 99 deletions

View File

@@ -162,6 +162,7 @@ class EngineArgs:
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_disable_mqa_scorer: Optional[bool] = False
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
@@ -640,6 +641,12 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-mqa-scorer',
action='store_true',
help=
'If set to True, the MQA scorer will be disabled in speculative '
' and fall back to batch expansion')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
@@ -970,6 +977,7 @@ class EngineArgs:
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,