[RFC] [Mistral] FP8 format (#10130)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
commit d366ccc4e3
parent 870c37481e
committed by GitHub
@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
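The three added entries translate the scale-parameter names in Mistral's FP8 checkpoint format into the names vLLM's quantized layers expect: per-tensor activation scales (qscale_act -> input_scale), weight scales (qscale_weight -> weight_scale), and the KV-cache scale (kv_fake_quantizer.qscale_act -> kv_scale). A minimal sketch of the intended renaming follows; the sample checkpoint key is hypothetical, made up for illustration:

# Sketch of how the new entries rename FP8 scale parameters.
# The sample key below is hypothetical, for illustration only.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_weight": "weight_scale",
    "wq": "q_proj",
}

name = "layers.0.attention.wq.qscale_weight"
for item, repl in mapping.items():
    name = name.replace(item, repl)
print(name)  # model.layers.0.self_attn.q_proj.weight_scale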
@@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         modules = name.split(".")
 
         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)
 
-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = (f"{item}.{next_item}"
+                             if next_item is not None else None)
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])
 
         return name, loaded_weight
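Two things change in this hunk. First, the rotary permutation now fires only when the last name component is "weight", so the new FP8 scale tensors (whose names also contain "wk" or "wq") are not permuted. Second, the remapping loop looks ahead one token so that a mapping key spanning two dotted components, such as "kv_fake_quantizer.qscale_act", can match; a per-token lookup can never see it. A standalone sketch of the new loop, applied to a hypothetical FP8 checkpoint key:

def remap(name: str, mapping: dict[str, str]) -> str:
    # Same logic as the new loop: try the two-token key first,
    # then fall back to the single-token key.
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None
        combined_item = (f"{item}.{next_item}"
                         if next_item is not None else None)
        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name

mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
}
# Hypothetical FP8 checkpoint key, for illustration only:
print(remap("layers.0.attention.kv_fake_quantizer.qscale_act", mapping))
# -> model.layers.0.self_attn.kv_scale

Checking the combined key first also keeps "kv_fake_quantizer.qscale_act" from being caught by the plain "qscale_act" -> "input_scale" entry.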