[MTP] Validate that MTP weights are actually loaded (#35548)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -415,6 +415,26 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
if not is_fusion_moe_shared_experts_layer:
|
if not is_fusion_moe_shared_experts_layer:
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
|
|
||||||
|
# Validate that weights were loaded for each expected MTP layer.
|
||||||
|
loaded_layers: set[int] = set()
|
||||||
|
for param_name in loaded_params:
|
||||||
|
spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
|
||||||
|
if spec_layer is not None:
|
||||||
|
loaded_layers.add(spec_layer)
|
||||||
|
for layer_idx in range(
|
||||||
|
self.model.mtp_start_layer_idx,
|
||||||
|
self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
|
||||||
|
):
|
||||||
|
if layer_idx not in loaded_layers:
|
||||||
|
raise ValueError(
|
||||||
|
f"MTP speculative decoding layer {layer_idx} weights "
|
||||||
|
f"missing from checkpoint. The checkpoint may have "
|
||||||
|
f"been quantized without including the MTP layers. "
|
||||||
|
f"Use a checkpoint that includes MTP layer weights, "
|
||||||
|
f"or disable speculative decoding."
|
||||||
|
)
|
||||||
|
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user