[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@@ -50,6 +50,31 @@ def set_weight_attrs(
        setattr(weight, key, value)


def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
    """
    Replace a parameter of a layer while maintaining the ability to reload the weight.
    Called within implementations of the `process_weights_after_loading` method.

    This function should not be called on weights which are tied/shared.

    Args:
        layer: Layer containing parameter to replace
        param_name: Name of parameter to replace
        new_data: New data of the new parameter
    """
    # should not be used on a tied/shared param
    if isinstance(new_data, torch.nn.Parameter):
        new_data = new_data.data
    new_param = torch.nn.Parameter(new_data, requires_grad=False)

    old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
    if old_param is not None and hasattr(old_param, "weight_loader"):
        weight_loader = old_param.weight_loader
        set_weight_attrs(new_param, {"weight_loader": weight_loader})

    setattr(layer, param_name, new_param)


def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    parent_map = getattr(model, "packed_modules_mapping", None)
    parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
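For context, here is a minimal sketch of how a quantization method's `process_weights_after_loading` implementation might call `replace_parameter`. This is not part of the commit: the `MyFp8LinearMethod` class and the `_quantize_to_fp8` helper are hypothetical stand-ins for a real vLLM linear method, shown only to illustrate why the `weight_loader` attribute is carried over.

    import torch

    def _quantize_to_fp8(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Hypothetical per-tensor quantization: compute a scale, then cast to FP8.
        scale = weight.abs().max() / torch.finfo(torch.float8_e4m3fn).max
        return (weight / scale).to(torch.float8_e4m3fn), scale

    class MyFp8LinearMethod:
        def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
            qweight, scale = _quantize_to_fp8(layer.weight.data)
            # replace_parameter copies the old parameter's weight_loader onto
            # the new FP8 parameter, so a later weight sync (e.g. an RL rollout
            # update) can reload fresh weights through the same loading path
            # instead of hitting a parameter that lost its loader.
            replace_parameter(layer, "weight", qweight)
            layer.weight_scale = torch.nn.Parameter(scale, requires_grad=False)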