Add LoRA support for Mixtral (#2831)
* add mixtral lora support * formatting * fix incorrectly ported logic * polish tests * minor fixes and refactoring * minor fixes * formatting * rename and remove redundant logic * refactoring * refactoring * minor fix * minor refactoring * fix code smell
This commit is contained in:
@@ -269,7 +269,32 @@ class LlamaModel(nn.Module):
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
supports_lora = True
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj",
|
||||
"o_proj",
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
"embed_tokens",
|
||||
"lm_head",
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -281,11 +306,11 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
|
||||
unpadded_vocab_size = config.vocab_size
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
unpadded_vocab_size,
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
@@ -293,7 +318,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
)
|
||||
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
|
||||
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user