[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)
Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
@@ -763,7 +763,7 @@ class ModelConfig:
|
||||
def is_deepseek_mla(self) -> bool:
|
||||
return (hasattr(self.hf_text_config, "model_type")) \
|
||||
and (self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3'))\
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
|
||||
and (self.hf_text_config.kv_lora_rank is not None)
|
||||
|
||||
def get_head_size(self) -> int:
|
||||
@@ -856,8 +856,12 @@ class ModelConfig:
|
||||
def get_layers_start_end_indices(
|
||||
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_hidden_layers", 0)
|
||||
if self.hf_text_config.model_type == "deepseek_mtp":
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_nextn_predict_layers", 0)
|
||||
else:
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_hidden_layers", 0)
|
||||
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
|
||||
pp_size = parallel_config.pipeline_parallel_size
|
||||
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
|
||||
@@ -1689,6 +1693,18 @@ class SpeculativeConfig:
|
||||
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
if hf_config.model_type == "deepseek_v3":
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["DeepSeekMTPModel"]
|
||||
})
|
||||
return hf_config
|
||||
|
||||
@staticmethod
|
||||
def maybe_create_spec_config(
|
||||
target_model_config: ModelConfig,
|
||||
@@ -1771,12 +1787,18 @@ class SpeculativeConfig:
|
||||
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
|
||||
the necessary conditions are met, else None.
|
||||
"""
|
||||
|
||||
if speculative_model is None:
|
||||
if num_speculative_tokens is not None:
|
||||
raise ValueError("num_speculative_tokens was provided without "
|
||||
"speculative_model.")
|
||||
return None
|
||||
if target_model_config.hf_text_config.model_type \
|
||||
== "deepseek_v3":
|
||||
# use the draft model from the same model:
|
||||
speculative_model = target_model_config.model
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided without "
|
||||
"speculative_model.")
|
||||
else:
|
||||
return None
|
||||
|
||||
if (speculative_disable_by_batch_size is not None
|
||||
and speculative_disable_by_batch_size < 2):
|
||||
@@ -1830,6 +1852,7 @@ class SpeculativeConfig:
|
||||
max_seq_len_to_capture=target_model_config.
|
||||
max_seq_len_to_capture,
|
||||
max_logprobs=target_model_config.max_logprobs,
|
||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||
)
|
||||
|
||||
draft_hf_config = draft_model_config.hf_config
|
||||
@@ -1846,7 +1869,6 @@ class SpeculativeConfig:
|
||||
if (num_speculative_tokens is not None
|
||||
and hasattr(draft_hf_config, "num_lookahead_tokens")):
|
||||
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
|
||||
|
||||
n_predict = getattr(draft_hf_config, "n_predict", None)
|
||||
if n_predict is not None:
|
||||
if num_speculative_tokens is None:
|
||||
@@ -1960,8 +1982,9 @@ class SpeculativeConfig:
|
||||
speculative_draft_tensor_parallel_size = 1
|
||||
if target_parallel_config.tensor_parallel_size > 1:
|
||||
logger.warning(
|
||||
"MLPSpeculator cannot currently be run with tp>1; "
|
||||
"setting speculative_draft_tensor_parallel_size=1")
|
||||
"%s cannot currently be run with tp>1; "
|
||||
"setting speculative_draft_tensor_parallel_size=1",
|
||||
draft_hf_config.model_type)
|
||||
else:
|
||||
speculative_draft_tensor_parallel_size = \
|
||||
target_parallel_config.tensor_parallel_size
|
||||
|
||||
Reference in New Issue
Block a user