Remove all cases of fmt: on/off (#26253)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -516,14 +516,18 @@ class VoxtralForConditionalGeneration(
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
# fmt: off
|
||||
remapping_rules = [
|
||||
(r"mm_whisper_embeddings\.(.*)", r"\1"),
|
||||
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
|
||||
(r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501
|
||||
(r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501
|
||||
(
|
||||
r"audio_language_adapter\.0\.weight",
|
||||
r"audio_language_adapter.w_in.weight",
|
||||
),
|
||||
(
|
||||
r"audio_language_adapter\.2\.weight",
|
||||
r"audio_language_adapter.w_out.weight",
|
||||
),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
audio_params = dict(
|
||||
nn.ModuleDict(
|
||||
@@ -678,19 +682,44 @@ class AudioLanguageAdapter(nn.Module):
|
||||
class VoxtralEncoderModel(nn.Module):
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
# fmt: off
|
||||
mistral_remapping = [
|
||||
(r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501
|
||||
(r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501
|
||||
(
|
||||
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
|
||||
r"whisper_encoder.conv1.\1",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
|
||||
r"whisper_encoder.conv2.\1",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.mlp.fc1.\2",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.mlp.fc2.\2",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
|
||||
r"whisper_encoder.layers.\1.final_layer_norm.\2",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.norm\.(weight|bias)",
|
||||
r"whisper_encoder.layer_norm.\1",
|
||||
),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user