[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
This commit is contained in:
@@ -30,11 +30,12 @@ from transformers import OlmoConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -54,7 +55,7 @@ class OlmoAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -79,7 +80,7 @@ class OlmoAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=config.attention_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Rotary embeddings.
|
||||
@@ -99,7 +100,7 @@ class OlmoAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -129,7 +130,7 @@ class OlmoMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -141,7 +142,7 @@ class OlmoMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Activation function.
|
||||
@@ -152,7 +153,7 @@ class OlmoMLP(nn.Module):
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = OlmoAttention(config, linear_method)
|
||||
self.self_attn = OlmoAttention(config, quant_config)
|
||||
|
||||
# MLP block.
|
||||
self.mlp = OlmoMLP(config, linear_method)
|
||||
self.mlp = OlmoMLP(config, quant_config)
|
||||
|
||||
# LayerNorm
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
@@ -216,14 +217,14 @@ class OlmoModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.layers = nn.ModuleList([
|
||||
OlmoDecoderLayer(config, linear_method)
|
||||
OlmoDecoderLayer(config, quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = nn.LayerNorm(config.hidden_size,
|
||||
@@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = OlmoModel(config, linear_method)
|
||||
self.model = OlmoModel(config, quant_config)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user