[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

Author: Cody Yu
Date: 2024-04-26 13:41:14 -07:00
Committed by: GitHub
Parent: 603ad84815
Commit: a62aaf1df5
45 changed files with 759 additions and 713 deletions
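
The change is mechanical across the touched files: layer and model constructors that previously took a pre-built LinearMethodBase via linear_method= now take the QuantizationConfig itself via quant_config=, and each layer presumably derives its quantized (or plain) implementation from that config. A minimal sketch of the new calling convention, using only names that appear in the mpt.py hunks below; make_out_proj and the d_model value are illustrative placeholders, not part of this commit:

from typing import Optional

from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


def make_out_proj(d_model: int,
                  quant_config: Optional[QuantizationConfig] = None):
    # quant_config is forwarded as-is; passing None keeps the ordinary
    # unquantized linear layer, exactly as in the MPTAttention hunks below.
    return RowParallelLinear(
        d_model,
        d_model,
        bias=True,
        quant_config=quant_config,
    )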

vllm/model_executor/models/mpt.py

@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)
@@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             self.d_model,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )

         tp_world_size = get_tensor_model_parallel_world_size()
@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
@@ -143,15 +144,15 @@ class MPTMLP(nn.Module):
             hidden_size,
             intermediate_size,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
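
Note that get_act_fn keeps taking a quantization config and the intermediate size as its second and third arguments; the hunk above only changes where that config comes from. A hedged usage sketch (the intermediate_size value is a placeholder; whether the activation gets wrapped when a real config is supplied depends on the quantization method):

from vllm.model_executor.layers.activation import get_act_fn

intermediate_size = 4 * 2048  # placeholder for expansion_ratio * d_model
# Same positional call as in MPTMLP above; with a None config this returns
# the plain GELU module.
act = get_act_fn("gelu", None, intermediate_size)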
@@ -166,14 +167,14 @@ class MPTBlock(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config, linear_method)
+        self.attn = MPTAttention(config, quant_config)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config, linear_method)
+        self.ffn = MPTMLP(config, quant_config)

     def forward(
         self,
@@ -201,7 +202,7 @@ class MPTModel(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         assert config.embedding_fraction == 1.0
@@ -212,7 +213,7 @@ class MPTModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
+            [MPTBlock(config, quant_config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -246,14 +247,14 @@ class MPTForCausalLM(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
-        self.linear_method = linear_method
-        self.transformer = MPTModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = MPTModel(config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
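
Downstream code that introspected the old attribute needs the same rename: the model now exposes self.quant_config instead of self.linear_method. A small hypothetical helper illustrating the new attribute (describe_quantization is not part of this commit):

def describe_quantization(model) -> str:
    # After this commit the model stores the QuantizationConfig directly;
    # None means the model runs unquantized.
    quant_config = getattr(model, "quant_config", None)
    if quant_config is None:
        return "unquantized"
    return type(quant_config).__name__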