[Model][VLM] Decouple weight loading logic for Paligemma (#8269)

This commit is contained in:
Isotr0py
2024-09-08 01:45:44 +08:00
committed by GitHub
parent e807125936
commit 36bf8150cc
2 changed files with 51 additions and 78 deletions

View File

@@ -529,6 +529,12 @@ class SiglipVisionModel(nn.Module):
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)
@@ -544,7 +550,16 @@ class SiglipVisionModel(nn.Module):
if layer_idx >= layer_count:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)