[Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
Benjamin Chislett
2025-02-27 18:28:08 -05:00
committed by GitHub
parent 2e94b9cfbb
commit 9804145cac
6 changed files with 49 additions and 22 deletions

View File

@@ -1978,13 +1978,12 @@ class SpeculativeConfig:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
num_speculative_tokens = n_predict
elif num_speculative_tokens > n_predict:
# Verify provided value doesn't exceed the maximum
# supported by the draft model.
elif num_speculative_tokens > n_predict and \
num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
"This speculative model supports a maximum of "
f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.")
f"{num_speculative_tokens=} must be divisible by "
f"{n_predict=}")
speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(