[Model] Allow loading from original Mistral format (#8168)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
committed by GitHub
parent 23f322297f
commit 29f49cd6e3
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }
 
     def __init__(
         self,
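For orientation, the renaming this mapping drives can be sketched in isolation; the checkpoint key below is a made-up example, and the real logic lives in maybe_remap_mistral further down the diff:

```python
# Minimal sketch of the substring renaming driven by mistral_mapping.
# The key below is a hypothetical consolidated.safetensors entry.
mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
}

name = "layers.0.attention.wq.weight"
for item in name.split("."):
    if item in mistral_mapping and mistral_mapping[item] not in name:
        name = name.replace(item, mistral_mapping[item])

print(name)  # model.layers.0.self_attn.q_proj.weight
```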
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
+
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
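With that hook in place in load_weights, the new path is exercised by selecting the mistral load format; a minimal usage sketch, assuming the chosen repository ships a consolidated.safetensors file (the model name here is only an example):

```python
from vllm import LLM

# Sketch: load weights from the original (consolidated.safetensors) layout.
# Equivalent to passing --load-format mistral on the command line.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # example repo with consolidated weights
    load_format="mistral",
)
outputs = llm.generate("Hello, my name is")
```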
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
+
+    # This function is used to remap the mistral format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_mistral(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
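A note on the permute helper above: consolidated checkpoints keep the wq/wk rows of each head in the original interleaved rotary layout, and this reshape regroups them into the half-split layout the HF-style Llama attention expects (it appears to be the same permutation used when converting original Llama weights to the Hugging Face format). A toy-shape sketch, with all sizes made up:

```python
import torch

def permute(w, n_heads, head_dim, hidden_size):
    # Same reshape as in maybe_remap_mistral, with the sizes passed in
    # explicitly instead of read from self.config.
    attn_in = head_dim * n_heads
    attn_out = hidden_size
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

# Toy sizes: 2 heads of dim 4 over a hidden size of 8 (hypothetical).
w = torch.arange(8 * 8, dtype=torch.float32).reshape(8, 8)
print(permute(w, n_heads=2, head_dim=4, hidden_size=8).shape)  # torch.Size([8, 8])
```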