[MTP] Validate that MTP weights are actually loaded (#35548)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-02-27 23:27:40 -05:00
committed by GitHub
parent fd68cd132b
commit 2562e0271e

View File

@@ -415,6 +415,26 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
# Validate that weights were loaded for each expected MTP layer.
loaded_layers: set[int] = set()
for param_name in loaded_params:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
if spec_layer is not None:
loaded_layers.add(spec_layer)
for layer_idx in range(
self.model.mtp_start_layer_idx,
self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
):
if layer_idx not in loaded_layers:
raise ValueError(
f"MTP speculative decoding layer {layer_idx} weights "
f"missing from checkpoint. The checkpoint may have "
f"been quantized without including the MTP layers. "
f"Use a checkpoint that includes MTP layer weights, "
f"or disable speculative decoding."
)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: