[BugFix] Fix EPLB fail for MoeFP4 model with Marlin backend (#33262)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
@@ -44,7 +44,9 @@ def set_weight_attrs(
|
||||
setattr(weight, key, value)
|
||||
|
||||
|
||||
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
|
||||
def replace_parameter(
|
||||
layer: torch.nn.Module, param_name: str, new_data: torch.Tensor | None
|
||||
):
|
||||
"""
|
||||
Replace a parameter of a layer while maintaining the ability to reload the weight.
|
||||
Called within implementations of the `process_weights_after_loading` method.
|
||||
@@ -54,9 +56,15 @@ def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.T
|
||||
Args:
|
||||
layer: Layer containing parameter to replace
|
||||
param_name: Name of parameter to replace
|
||||
new_data: New data of the new parameter
|
||||
new_data: New data of the new parameter, or None to set the parameter to None
|
||||
"""
|
||||
# should not be used on a tied/shared param
|
||||
|
||||
# If new_data is None, set the parameter to None
|
||||
if new_data is None:
|
||||
setattr(layer, param_name, None)
|
||||
return
|
||||
|
||||
if isinstance(new_data, torch.nn.Parameter):
|
||||
new_data = new_data.data
|
||||
new_param = torch.nn.Parameter(new_data, requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user