[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

Author: Cody Yu
Date: 2024-04-26 13:41:14 -07:00
Committed by: GitHub
Parent: 603ad84815
Commit: a62aaf1df5
45 changed files with 759 additions and 713 deletions

vllm/model_executor/models/opt.py

@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
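Note on the shape of the refactor: callers stop threading a concrete LinearMethodBase through every layer and instead pass the higher-level QuantizationConfig, from which each layer can resolve its own quantization method. A minimal sketch of that pattern, assuming helper names like get_quant_method and UnquantizedLinearMethod that are not shown in this hunk:

from typing import Optional

class QuantizeMethodBase:
    """Base class for per-layer quantization methods (assumed name)."""

class UnquantizedLinearMethod(QuantizeMethodBase):
    """Plain fp16/bf16 matmul; used when no quantization is configured."""

class QuantizationConfig:
    def get_quant_method(self, layer) -> Optional[QuantizeMethodBase]:
        """Pick the method for `layer` (assumed helper, simplified)."""
        raise NotImplementedError

class SomeLinearLayer:
    # Before: the caller built a LinearMethodBase and passed it in.
    # After: the layer receives quant_config and resolves the method itself.
    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        if quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)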
@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         bias: bool = True,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
             self.head_dim,
             total_num_heads,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.out_proj = RowParallelLinear(
             embed_dim,
             embed_dim,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=config.num_attention_heads,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.do_layer_norm_before = config.do_layer_norm_before

@@ -127,16 +128,16 @@ class OPTDecoderLayer(nn.Module):
             self.embed_dim,
             config.ffn_dim,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.activation_fn = get_act_fn(config.activation_function,
                                         quant_config, config.ffn_dim)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
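The hunk above is the one place the old code needed the config itself rather than a method: get_act_fn may wrap the activation for quantization schemes that rescale activations. A simplified sketch of how such a helper can consume the config, loosely following vLLM's activation layer of this era (the registry, the wrapper body, and get_scaled_act_names are illustrative assumptions, not copied from the commit):

import torch
import torch.nn as nn

# Illustrative registry; the real one maps many more activation names.
_ACTIVATION_REGISTRY = {"relu": nn.ReLU(), "gelu": nn.GELU()}

class ScaledActivation(nn.Module):
    """Applies a learned per-channel scale after the activation, as
    needed by schemes (e.g. AWQ) that rescale activations."""

    def __init__(self, act_fn: nn.Module, intermediate_size: int):
        super().__init__()
        self.act_fn = act_fn
        self.scales = nn.Parameter(torch.ones(intermediate_size))

    def forward(self, x):
        return self.act_fn(x) / self.scales

def get_act_fn(act_fn_name, quant_config=None, intermediate_size=None):
    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        # Only wrap when the scheme declares this activation as scaled.
        return ScaledActivation(act_fn, intermediate_size)
    return act_fn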
@@ -181,7 +182,7 @@ class OPTDecoder(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -202,7 +203,7 @@ class OPTDecoder(nn.Module):
             self.project_out = ReplicatedLinear(config.hidden_size,
                                                 config.word_embed_proj_dim,
                                                 bias=False,
-                                                linear_method=linear_method)
+                                                quant_config=quant_config)
         else:
             self.project_out = None

@@ -210,7 +211,7 @@ class OPTDecoder(nn.Module):
             self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                                config.hidden_size,
                                                bias=False,
-                                               linear_method=linear_method)
+                                               quant_config=quant_config)
         else:
             self.project_in = None

@@ -226,7 +227,7 @@ class OPTDecoder(nn.Module):
             self.final_layer_norm = None

         self.layers = nn.ModuleList([
-            OPTDecoderLayer(config, linear_method)
+            OPTDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
@@ -259,10 +260,10 @@ class OPTModel(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.decoder = OPTDecoder(config, linear_method)
+        self.decoder = OPTDecoder(config, quant_config)

     def forward(
         self,
@@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = OPTModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = OPTModel(config, quant_config)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
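At the call site, the change means model classes are now handed the quantization config directly instead of a pre-built linear method. A minimal usage sketch; AWQConfig stands in for any QuantizationConfig subclass, and its exact constructor signature here is an assumption:

from transformers import OPTConfig

from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.opt import OPTForCausalLM

# Hypothetical AWQ-style config; constructor arguments are assumptions.
quant_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)

hf_config = OPTConfig.from_pretrained("facebook/opt-125m")

# Before: OPTForCausalLM(hf_config, linear_method=...)  # a LinearMethodBase
model = OPTForCausalLM(hf_config, quant_config=quant_config)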