Add LoRA support for Mixtral (#2831)

* add mixtral lora support

* formatting

* fix incorrectly ported logic

* polish tests

* minor fixes and refactoring

* minor fixes

* formatting

* rename and remove redundant logic

* refactoring

* refactoring

* minor fix

* minor refactoring

* fix code smell
This commit is contained in:
Terry
2024-02-13 15:55:45 -08:00
committed by GitHub
parent 317b29de0f
commit 2a543d6efe
10 changed files with 251 additions and 121 deletions

View File

@@ -269,7 +269,32 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
@@ -281,11 +306,11 @@ class LlamaForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = LlamaModel(config, linear_method, lora_config=lora_config)
-        unpadded_vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
         if lora_config:
-            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.lm_head = ParallelLMHead(
-            unpadded_vocab_size,
+            self.unpadded_vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
@@ -293,7 +318,7 @@ class LlamaForCausalLM(nn.Module):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
+        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

     def forward(
         self,