[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) (#6050)

Co-authored-by: Sirej Dua <sirej.dua@databricks.com>
Co-authored-by: Sirej Dua <Sirej Dua>
This commit is contained in:
Sirej Dua
2024-07-02 07:20:29 -07:00
committed by GitHub
parent 31354e563f
commit 15aba081f3
3 changed files with 35 additions and 25 deletions

View File

@@ -957,12 +957,6 @@ class SpeculativeConfig:
)
draft_hf_config = draft_model_config.hf_config
if (draft_hf_config.model_type == "mlp_speculator"
and target_parallel_config.world_size != 1):
# MLPSpeculator TP support will be added very soon
raise ValueError(
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1).")
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):