[MTP] Validate that MTP weights are actually loaded (#35548)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -415,6 +415,26 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
||||
weight_loader(param, loaded_weight)
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
|
||||
# Validate that weights were loaded for each expected MTP layer.
|
||||
loaded_layers: set[int] = set()
|
||||
for param_name in loaded_params:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
|
||||
if spec_layer is not None:
|
||||
loaded_layers.add(spec_layer)
|
||||
for layer_idx in range(
|
||||
self.model.mtp_start_layer_idx,
|
||||
self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
|
||||
):
|
||||
if layer_idx not in loaded_layers:
|
||||
raise ValueError(
|
||||
f"MTP speculative decoding layer {layer_idx} weights "
|
||||
f"missing from checkpoint. The checkpoint may have "
|
||||
f"been quantized without including the MTP layers. "
|
||||
f"Use a checkpoint that includes MTP layer weights, "
|
||||
f"or disable speculative decoding."
|
||||
)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user