[Spec Decode] (1/2) Remove batch expansion (#8839)

This commit is contained in:
Lily Liu
2024-10-01 16:04:42 -07:00
committed by GitHub
parent 22f5851b80
commit 1570203864
29 changed files with 531 additions and 99 deletions

View File

@@ -1116,6 +1116,7 @@ class SpeculativeConfig:
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
@@ -1150,6 +1151,9 @@ class SpeculativeConfig:
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
@@ -1304,6 +1308,7 @@ class SpeculativeConfig:
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_mqa_scorer,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
@@ -1400,6 +1405,7 @@ class SpeculativeConfig:
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
speculative_disable_mqa_scorer: Optional[bool],
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
@@ -1446,6 +1452,7 @@ class SpeculativeConfig:
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0