[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)
Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
@@ -732,13 +732,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
# TODO(simon): support nextn predict layers
|
||||
if hasattr(self.config, "num_nextn_predict_layers"
|
||||
) and self.config.num_nextn_predict_layers > 0:
|
||||
assert self.config.num_nextn_predict_layers == 1
|
||||
layer_idx = self.config.num_hidden_layers
|
||||
if name.startswith(f"model.layers.{layer_idx}"):
|
||||
continue
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
@@ -805,3 +801,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM entry point.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from ``DeepseekV2ForCausalLM``.
    """
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                        weight_name: str) -> Optional[int]:
    """Map a checkpoint weight name to the next-n predict layer that owns it.

    Next-n predict layers are taken to occupy the layer indices directly
    after the regular ones, i.e. ``num_hidden_layers`` up to
    ``num_hidden_layers + num_nextn_predict_layers - 1``.

    Args:
        config: Model config. ``num_nextn_predict_layers`` is optional;
            ``num_hidden_layers`` is read only when the former is positive.
        weight_name: Full weight name, e.g. ``"model.layers.61.mlp..."``.

    Returns:
        The absolute layer index if ``weight_name`` starts with
        ``"model.layers.<idx>."`` for one of those indices, otherwise ``None``
        (also when the config declares no next-n predict layers).
    """
    num_spec_layers = getattr(config, "num_nextn_predict_layers", 0)
    if not num_spec_layers > 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for spec_idx in range(first_spec_idx, first_spec_idx + num_spec_layers):
        # The trailing dot keeps e.g. layer 61 from matching "model.layers.610.".
        if weight_name.startswith(f"model.layers.{spec_idx}."):
            return spec_idx
    return None
|
||||
|
||||
Reference in New Issue
Block a user