[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2024-11-07 17:15:14 +01:00
committed by GitHub
parent ae62fd17c0
commit 9d43afcc53
17 changed files with 476 additions and 146 deletions

View File

@@ -192,7 +192,6 @@ class ModelConfig:
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta,
config_format)
@@ -1317,13 +1316,6 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo becomes valid
if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
@@ -1390,6 +1382,12 @@ class SpeculativeConfig:
f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.")
if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,