[Bugfix][MTP] Fix GLM4 MoE fp8 loading with MTP on (#31757)
Signed-off-by: Andy Liu <andyliu@roblox.com>
This commit is contained in:
@@ -106,7 +106,7 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
     ) -> torch.Tensor:
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
-        inputs_embeds[positions == 0] = 0
+        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
 
@@ -268,6 +268,11 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
             if spec_layer is None:
                 continue
             name = self._rewrite_spec_layer_name(spec_layer, name)
+            # Some checkpoints include weight scale tensors for the LM head even
+            # when the quantized head isn't built. Skip them if the model does
+            # not expose a matching parameter to avoid KeyError during load.
+            if name.endswith(".weight_scale") and name not in params_dict:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
Reference in New Issue
Block a user