[RFC] [Mistral] FP8 format (#10130)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Patrick von Platen
2025-02-08 22:12:53 +01:00
committed by GitHub
parent 870c37481e
commit d366ccc4e3
4 changed files with 55 additions and 12 deletions

View File

@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
@@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
modules = name.split(".")
# rotary embeds should be sliced
if "wk" in modules:
if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules:
elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
for item in modules:
if item in mapping and mapping[item] not in name:
num_modules = len(modules)
for i in range(num_modules):
item = modules[i]
next_item = modules[i + 1] if i < num_modules - 1 else None
combined_item = (f"{item}.{next_item}"
if next_item is not None else None)
if combined_item in mapping:
name = name.replace(combined_item, mapping[combined_item])
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight