[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Author: Lucia Fang
Date: 2025-02-19 01:06:23 -08:00 (committed by GitHub)
Parent: 983a40a8bb
Commit: f525c0be8b
14 changed files with 727 additions and 46 deletions

vllm/config.py

@@ -763,7 +763,7 @@ class ModelConfig:
     def is_deepseek_mla(self) -> bool:
         return (hasattr(self.hf_text_config, "model_type")) \
             and (self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3'))\
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
             and (self.hf_text_config.kv_lora_rank is not None)

     def get_head_size(self) -> int:
@@ -856,8 +856,12 @@ class ModelConfig:
     def get_layers_start_end_indices(
             self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
+        if self.hf_text_config.model_type == "deepseek_mtp":
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 0)
+        else:
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
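
For a DeepSeek-V3 checkpoint, num_nextn_predict_layers is 1, so the MTP draft
model is partitioned over its single next-n prediction layer rather than the
target's full num_hidden_layers. A minimal sketch of the resulting indexing,
assuming get_pp_indices keeps its (num_layers, pp_rank, pp_size) call shape:

    from vllm.distributed.utils import get_pp_indices

    # One MTP layer and no pipeline parallelism: the sole draft layer
    # lands on the only rank.
    start, end = get_pp_indices(1, 0, 1)
    assert (start, end) == (0, 1)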
@@ -1689,6 +1693,18 @@ class SpeculativeConfig:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str

+    @staticmethod
+    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
+        if hf_config.model_type == "deepseek_v3":
+            hf_config.model_type = "deepseek_mtp"
+        if hf_config.model_type == "deepseek_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["DeepSeekMTPModel"]
+            })
+        return hf_config
+
     @staticmethod
     def maybe_create_spec_config(
             target_model_config: ModelConfig,
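
The override rewrites a DeepSeek-V3 config in place so the draft model is
loaded as the MTP architecture. A rough sketch of its effect, using a bare
transformers PretrainedConfig as an illustrative stand-in for the real
checkpoint config:

    from transformers import PretrainedConfig

    cfg = PretrainedConfig(model_type="deepseek_v3",
                           num_nextn_predict_layers=1)
    cfg = SpeculativeConfig.hf_config_override(cfg)
    assert cfg.model_type == "deepseek_mtp"
    assert cfg.n_predict == 1
    assert cfg.architectures == ["DeepSeekMTPModel"]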
@@ -1771,12 +1787,18 @@ class SpeculativeConfig:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                 the necessary conditions are met, else None.
         """
         if speculative_model is None:
             if num_speculative_tokens is not None:
-                raise ValueError("num_speculative_tokens was provided without "
-                                 "speculative_model.")
-            return None
+                if target_model_config.hf_text_config.model_type \
+                        == "deepseek_v3":
+                    # use the draft model from the same model:
+                    speculative_model = target_model_config.model
+                else:
+                    raise ValueError(
+                        "num_speculative_tokens was provided without "
+                        "speculative_model.")
+            else:
+                return None

         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
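
With this branch, a DeepSeek-V3 target no longer needs an explicit draft
model: passing only num_speculative_tokens reuses the target checkpoint, whose
MTP weights act as the draft. A hedged end-to-end sketch (model path and token
count are illustrative):

    from vllm import LLM

    # No speculative_model given: for a deepseek_v3 target the checkpoint
    # is reused and its MTP module serves as the draft model.
    llm = LLM(model="deepseek-ai/DeepSeek-V3",
              num_speculative_tokens=1,
              trust_remote_code=True)
    print(llm.generate("The capital of France is"))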
@@ -1830,6 +1852,7 @@ class SpeculativeConfig:
                 max_seq_len_to_capture=target_model_config.
                 max_seq_len_to_capture,
                 max_logprobs=target_model_config.max_logprobs,
+                hf_overrides=SpeculativeConfig.hf_config_override,
             )

             draft_hf_config = draft_model_config.hf_config
@@ -1846,7 +1869,6 @@ class SpeculativeConfig:
             if (num_speculative_tokens is not None
                     and hasattr(draft_hf_config, "num_lookahead_tokens")):
                 draft_hf_config.num_lookahead_tokens = num_speculative_tokens
-
             n_predict = getattr(draft_hf_config, "n_predict", None)
             if n_predict is not None:
                 if num_speculative_tokens is None:
@@ -1960,8 +1982,9 @@ class SpeculativeConfig:
             speculative_draft_tensor_parallel_size = 1
             if target_parallel_config.tensor_parallel_size > 1:
                 logger.warning(
-                    "MLPSpeculator cannot currently be run with tp>1; "
-                    "setting speculative_draft_tensor_parallel_size=1")
+                    "%s cannot currently be run with tp>1; "
+                    "setting speculative_draft_tensor_parallel_size=1",
+                    draft_hf_config.model_type)
         else:
             speculative_draft_tensor_parallel_size = \
                 target_parallel_config.tensor_parallel_size
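
The warning now interpolates the draft's model_type lazily through logging's
printf-style arguments instead of hard-coding "MLPSpeculator". A small
self-contained reproduction (logger name illustrative):

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger("vllm.config")
    logger.warning(
        "%s cannot currently be run with tp>1; "
        "setting speculative_draft_tensor_parallel_size=1",
        "deepseek_mtp")
    # WARNING:vllm.config:deepseek_mtp cannot currently be run with tp>1;
    # setting speculative_draft_tensor_parallel_size=1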