Add LoRA support for Mixtral (#2831)

* add mixtral lora support

* formatting

* fix incorrectly ported logic

* polish tests

* minor fixes and refactoring

* minor fixes

* formatting

* rename and remove redundant logic

* refactoring

* refactoring

* minor fix

* minor refactoring

* fix code smell
This commit is contained in:
Terry
2024-02-13 15:55:45 -08:00
committed by GitHub
parent 317b29de0f
commit 2a543d6efe
10 changed files with 251 additions and 121 deletions

View File

@@ -265,7 +265,32 @@ class MistralModel(nn.Module):
class MistralForCausalLM(nn.Module):
    # Flag advertising that this model can load LoRA adapters
    # (presumably consumed by the framework's LoRA loader — verify against caller).
    supports_lora = True
    # Maps each fused projection module name to the list of original
    # (unfused) checkpoint module names it packs together, so LoRA weights
    # trained against the unfused modules can be routed into the fused one.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    # Module names (post-fusion) that are allowed to carry LoRA adapters.
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    # Maps embedding-layer module names to the adapter-side tensor names
    # ("input_embeddings"/"output_embeddings") they correspond to.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    # Modules whose weight may be padded — NOTE(review): presumably to make
    # room for extra LoRA vocabulary tokens; confirm against the LoRA manager.
    embedding_padding_modules = ["lm_head"]
def __init__(
self,