[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
This commit is contained in:
@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding, linear_method)
|
||||
BaiChuanDecoderLayer(config, position_embedding, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(config, "ROPE", linear_method, lora_config)
|
||||
super().__init__(config, "ROPE", quant_config, lora_config)
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(config, "ALIBI", linear_method, lora_config)
|
||||
super().__init__(config, "ALIBI", quant_config, lora_config)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__(config, "ROPE", linear_method, lora_config)
|
||||
super().__init__(config, "ROPE", quant_config, lora_config)
|
||||
|
||||
Reference in New Issue
Block a user