Enable more models to inference based on LoRA (#3382)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Authored by Jee Li on 2024-03-26 09:09:31 +08:00 and committed via GitHub.
parent dfeb2ecc3a
commit 8af890a865
10 changed files with 401 additions and 44 deletions

View File

@@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -282,11 +283,30 @@ class BaiChuanModel(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
def __init__(
self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
@@ -371,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    Thin dispatcher over ``BaiChuanBaseForCausalLM`` that selects the
    positional-encoding scheme from the model's hidden size:
    hidden_size == 4096 corresponds to Baichuan2-7B (rotary embeddings),
    anything else to Baichuan-13B / Baichuan2-13B (ALiBi).
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        # NOTE(review): dispatch keys on hidden_size alone — presumably no
        # other Baichuan variant shares hidden_size 4096; verify if new
        # checkpoints are added.
        if config.hidden_size == 4096:  # baichuan2 7b
            super().__init__(config, "ROPE", linear_method, lora_config)
        else:  # baichuan 13b, baichuan2 13b
            super().__init__(config, "ALIBI", linear_method, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.

    Always uses rotary positional embeddings ("ROPE"); everything else is
    delegated unchanged to ``BaiChuanBaseForCausalLM``.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__(config, "ROPE", linear_method, lora_config)