diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index d4e87707c..0b844d149 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -44,7 +44,9 @@ def set_weight_attrs( setattr(weight, key, value) -def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor): +def replace_parameter( + layer: torch.nn.Module, param_name: str, new_data: torch.Tensor | None +): """ Replace a parameter of a layer while maintaining the ability to reload the weight. Called within implementations of the `process_weights_after_loading` method. @@ -54,9 +56,15 @@ def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.T Args: layer: Layer containing parameter to replace param_name: Name of parameter to replace - new_data: New data of the new parameter + new_data: New data of the new parameter, or None to set the parameter to None """ # should not be used on a tied/shared param + + # If new_data is None, set the parameter to None + if new_data is None: + setattr(layer, param_name, None) + return + if isinstance(new_data, torch.nn.Parameter): new_data = new_data.data new_param = torch.nn.Parameter(new_data, requires_grad=False)