[Bugfix] Fix k_proj's bias for whisper self attention (#12342)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -729,7 +729,22 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
||||||
loaded_weights = [(name, loaded_weight)
|
|
||||||
for name, loaded_weight in weights]
|
|
||||||
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
|
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
|
||||||
return loader.load_weights(loaded_weights, mapper=mapper)
|
# add fake zeros bias for k_proj to state_dict
|
||||||
|
weights = _create_fake_bias_for_k_proj(weights)
|
||||||
|
return loader.load_weights(weights, mapper=mapper)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_fake_bias_for_k_proj(
|
||||||
|
weights: Iterable[Tuple[str, torch.Tensor]]
|
||||||
|
) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Create full zeros bias for k_proj weight in self-attention layers.
|
||||||
|
So that the bias for k_proj in qkv_proj can be initialized with zeros.
|
||||||
|
"""
|
||||||
|
for name, weight in weights:
|
||||||
|
if ".self_attn.k_proj.weight" in name:
|
||||||
|
bias = torch.zeros(weight.size(0))
|
||||||
|
bias_name = name.replace("weight", "bias")
|
||||||
|
yield from [(name, weight), (bias_name, bias)]
|
||||||
|
yield name, weight
|
||||||
|
|||||||
Reference in New Issue
Block a user