[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
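The diff below (the OPT model) is almost entirely mechanical: every constructor that previously took `linear_method: Optional[LinearMethodBase]` now takes `quant_config: Optional[QuantizationConfig]` and forwards it unchanged to its submodules. A minimal sketch of that threading pattern, using simplified stand-in classes rather than vLLM's real implementations:

from typing import Optional


class QuantizationConfig:
    """Simplified stand-in for vLLM's QuantizationConfig: describes the
    model-wide quantization scheme (e.g. AWQ or GPTQ)."""


class DecoderLayer:
    # Leaf modules now receive the config directly instead of a
    # pre-built LinearMethodBase.
    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        self.quant_config = quant_config


class Model:
    # Each enclosing module forwards the same parameter, exactly as
    # OPTModel -> OPTDecoder -> OPTDecoderLayer does in the diff below.
    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        self.layers = [DecoderLayer(quant_config) for _ in range(2)]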
@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         bias: bool = True,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
             self.head_dim,
             total_num_heads,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.out_proj = RowParallelLinear(
             embed_dim,
             embed_dim,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=config.num_attention_heads,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
 
@@ -127,16 +128,16 @@ class OPTDecoderLayer(nn.Module):
             self.embed_dim,
             config.ffn_dim,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.activation_fn = get_act_fn(config.activation_function,
                                         quant_config, config.ffn_dim)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
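The only non-mechanical edit in the hunk above is the `getattr` line before `get_act_fn`. Before this refactor, quantized `LinearMethodBase` objects carried their originating config as a `quant_config` attribute, and `getattr(..., None)` recovered it while falling back to `None` for the unquantized method. A small self-contained sketch of that idiom, with a toy stand-in for the method object:

from typing import Any, Optional


class _QuantMethod:
    """Toy stand-in for a quantized LinearMethodBase, which exposed its
    originating config as an attribute."""
    def __init__(self, quant_config: Any) -> None:
        self.quant_config = quant_config


def extract_quant_config(linear_method: Optional[Any]) -> Optional[Any]:
    # The pre-refactor idiom: works for quantized methods, and falls
    # back to None for unquantized methods (or no method at all).
    return getattr(linear_method, "quant_config", None)


assert extract_quant_config(_QuantMethod("awq")) == "awq"
assert extract_quant_config(None) is None

After the rename, the config arrives as a constructor argument, so there is no longer a method object to recover it from.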
@@ -181,7 +182,7 @@ class OPTDecoder(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -202,7 +203,7 @@ class OPTDecoder(nn.Module):
             self.project_out = ReplicatedLinear(config.hidden_size,
                                                 config.word_embed_proj_dim,
                                                 bias=False,
-                                                linear_method=linear_method)
+                                                quant_config=quant_config)
         else:
             self.project_out = None
 
@@ -210,7 +211,7 @@ class OPTDecoder(nn.Module):
             self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                                config.hidden_size,
                                                bias=False,
-                                               linear_method=linear_method)
+                                               quant_config=quant_config)
         else:
             self.project_in = None
 
@@ -226,7 +227,7 @@ class OPTDecoder(nn.Module):
             self.final_layer_norm = None
 
         self.layers = nn.ModuleList([
-            OPTDecoderLayer(config, linear_method)
+            OPTDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -259,10 +260,10 @@ class OPTModel(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.decoder = OPTDecoder(config, linear_method)
+        self.decoder = OPTDecoder(config, quant_config)
 
     def forward(
         self,
@@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = OPTModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = OPTModel(config, quant_config)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
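Downstream of this rename, in parts of the PR not shown in this file, a layer resolves its own quantize method from the config instead of being handed one. The sketch below assumes the `get_quant_method` and `UnquantizedLinearMethod` naming that vLLM adopted around this change; treat the exact signatures as approximate:

from typing import Optional


class QuantizeMethodBase:
    """Creates layer weights and applies the corresponding matmul."""


class UnquantizedLinearMethod(QuantizeMethodBase):
    """Plain floating-point path, used when no quant_config is given."""


class QuantizationConfig:
    def get_quant_method(self, layer: object) -> Optional[QuantizeMethodBase]:
        # Real configs return a scheme-specific QuantizeMethodBase here.
        raise NotImplementedError


class LinearBase:
    # Approximate shape of the post-refactor resolution: the layer asks
    # the config for its method instead of receiving a pre-built one.
    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedLinearMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self)

This is the sense in which the PR title generalizes `linear_method` to `quant_method`: a single QuantizationConfig can hand out per-layer quantize methods rather than one linear-only strategy.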