[Model] Allow loading from original Mistral format (#8168)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Author: Patrick von Platen
Date: 2024-09-07 01:02:05 +02:00
Committed by: GitHub
parent 23f322297f
commit 29f49cd6e3
7 changed files with 291 additions and 81 deletions

vllm/model_executor/models/llama.py

@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}
def __init__(
self,
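
For illustration (not part of this diff), the mapping above is what turns a consolidated.safetensors key into the Hugging Face-style name that load_weights looks up in params_dict. A minimal standalone sketch, using a hypothetical key and a trimmed copy of the mapping:

mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
}

name = "layers.0.attention.wq.weight"  # hypothetical consolidated-format key
for item in name.split("."):
    # only substitute tokens whose target is not already present in the name
    if item in mistral_mapping and mistral_mapping[item] not in name:
        name = name.replace(item, mistral_mapping[item])

print(name)  # model.layers.0.self_attn.q_proj.weight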
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)

            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
            self, name: str,
            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:

        def permute(w, n_heads):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size
            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
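
To make the "rotary embeds should be sliced" comment concrete, here is a small self-contained sketch (toy sizes, only PyTorch assumed) of the permutation applied to wq/wk above: within each head it reorders the projection rows from the interleaved pair layout of the original checkpoints into the evens-first, odds-second layout expected by the Hugging Face-style rotary embedding:

import torch

def permute(w: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    # same view/transpose/reshape as above, with head_dim passed in explicitly
    attn_in, attn_out = head_dim * n_heads, w.shape[1]
    return (w.view(n_heads, head_dim // 2, 2, attn_out)
            .transpose(1, 2)
            .reshape(attn_in, attn_out))

# Toy weight: 1 head, head_dim = 4, hidden_size = 4; row i holds the value i.
w = torch.arange(4.0).repeat_interleave(4).reshape(4, 4)
print(permute(w, n_heads=1, head_dim=4)[:, 0])
# tensor([0., 2., 1., 3.]) -- rows come out evens first, then odds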