[BugFix] Fix EPLB fail for MoeFP4 model with Marlin backend (#33262)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
Ilya Markov
2026-01-29 09:52:11 +01:00
committed by GitHub
parent 31b25f6516
commit 53fc166402

View File

@@ -44,7 +44,9 @@ def set_weight_attrs(
setattr(weight, key, value)
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
def replace_parameter(
layer: torch.nn.Module, param_name: str, new_data: torch.Tensor | None
):
"""
Replace a parameter of a layer while maintaining the ability to reload the weight.
Called within implementations of the `process_weights_after_loading` method.
@@ -54,9 +56,15 @@ def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.T
Args:
layer: Layer containing parameter to replace
param_name: Name of parameter to replace
new_data: New data of the new parameter
new_data: New data of the new parameter, or None to set the parameter to None
"""
# should not be used on a tied/shared param
# If new_data is None, set the parameter to None
if new_data is None:
setattr(layer, param_name, None)
return
if isinstance(new_data, torch.nn.Parameter):
new_data = new_data.data
new_param = torch.nn.Parameter(new_data, requires_grad=False)