[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
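Note: the import hunk above is the heart of the refactor. Callers no longer build a LinearMethodBase up front and hand it to every layer; they pass the QuantizationConfig itself, and each layer resolves the quantization method that applies to it. A minimal self-contained sketch of that pattern follows. The names mirror vLLM's (QuantizeMethodBase, UnquantizedLinearMethod, get_quant_method) but the bodies are simplified stand-ins, not the project's implementation:

    from abc import ABC, abstractmethod
    from typing import Optional

    import torch
    import torch.nn.functional as F


    class QuantizeMethodBase(ABC):
        """Simplified stand-in for the quant-method interface."""

        @abstractmethod
        def apply(self, layer: "LinearBase", x: torch.Tensor) -> torch.Tensor:
            """Run the (possibly quantized) computation for `layer`."""


    class UnquantizedLinearMethod(QuantizeMethodBase):
        """Fallback used when no quantization config is supplied."""

        def apply(self, layer: "LinearBase", x: torch.Tensor) -> torch.Tensor:
            return F.linear(x, layer.weight)  # plain full-precision matmul


    class QuantizationConfig(ABC):
        """Simplified stand-in for base_config.QuantizationConfig."""

        @abstractmethod
        def get_quant_method(self, layer) -> Optional[QuantizeMethodBase]:
            """Return the quant method that applies to `layer`, or None."""


    class LinearBase(torch.nn.Module):
        """Each layer now resolves its own method from the config."""

        def __init__(self, in_f: int, out_f: int,
                     quant_config: Optional[QuantizationConfig] = None):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_f, in_f))
            if quant_config is None:
                self.quant_method: QuantizeMethodBase = UnquantizedLinearMethod()
            else:
                self.quant_method = quant_config.get_quant_method(self)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.quant_method.apply(self, x)


    layer = LinearBase(4, 8)               # no config -> unquantized fallback
    print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 8])

With this inversion, adding a new quantization scheme means implementing a config that returns the right method per layer, instead of threading a new method object through every model file.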
@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
             self.hidden_size,
             [self.intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()

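Note: in CohereMLP the change is mechanical. MergedColumnParallelLinear and RowParallelLinear now accept quant_config= instead of linear_method=, and the MLP simply forwards whatever it was given. A toy version of that forwarding, with plain nn.Linear stand-ins in place of vLLM's tensor-parallel layers so it runs without distributed setup:

    from typing import Optional

    import torch
    from torch import nn


    class ToyQuantAwareLinear(nn.Linear):
        """Stand-in for a vLLM parallel linear layer: takes quant_config."""

        def __init__(self, in_features: int, out_features: int,
                     bias: bool = False,
                     quant_config: Optional[object] = None):
            super().__init__(in_features, out_features, bias=bias)
            # In vLLM this is where quant_config.get_quant_method(self)
            # would pick the kernel; here we just record the config.
            self.quant_config = quant_config


    class ToyMLP(nn.Module):
        """Mirrors CohereMLP's plumbing: forward quant_config down."""

        def __init__(self, hidden_size: int, intermediate_size: int,
                     quant_config: Optional[object] = None):
            super().__init__()
            self.gate_up_proj = ToyQuantAwareLinear(
                hidden_size, 2 * intermediate_size, quant_config=quant_config)
            self.down_proj = ToyQuantAwareLinear(
                intermediate_size, hidden_size, quant_config=quant_config)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # SiluAndMul: split the merged projection, gate one half.
            gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
            return self.down_proj(nn.functional.silu(gate) * up)


    mlp = ToyMLP(hidden_size=64, intermediate_size=256, quant_config=None)
    print(mlp(torch.randn(2, 64)).shape)  # torch.Size([2, 64])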
@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         tp_size = get_tensor_model_parallel_world_size()
@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
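Note: the attention projections get the same treatment. The reason the parameter is a config rather than a ready-made method is the generalization named in the title: a LinearMethodBase could only describe linear layers, while a config can hand out a different quant method per layer type. A hypothetical dispatch illustrating the idea (toy code, not vLLM's API):

    from typing import Optional

    from torch import nn


    class ToyQuantConfig:
        """Hypothetical config: returns a different (here, just named)
        quant method per layer type, which a single LinearMethodBase
        could not express."""

        def get_quant_method(self, layer: nn.Module) -> Optional[str]:
            if isinstance(layer, nn.Linear):
                return "int4-linear-kernel"       # quantize the matmuls
            if isinstance(layer, nn.Embedding):
                return "int8-embedding-lookup"    # quantize the table
            return None                           # leave e.g. LayerNorm alone


    cfg = ToyQuantConfig()
    print(cfg.get_quant_method(nn.Linear(8, 8)))   # int4-linear-kernel
    print(cfg.get_quant_method(nn.LayerNorm(8)))   # None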
@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):

     def __init__(self,
                  config: CohereConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = CohereAttention(config, linear_method=linear_method)
+        self.self_attn = CohereAttention(config, quant_config=quant_config)

-        self.mlp = CohereMLP(config, linear_method=linear_method)
+        self.mlp = CohereMLP(config, quant_config=quant_config)
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)

@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            CohereDecoderLayer(config, linear_method=linear_method)
+            CohereDecoderLayer(config, quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = LayerNorm(param_shape=(config.hidden_size),
@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
+        self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 scale=config.logit_scale)
-        self.model = CohereModel(config, linear_method)
+        self.model = CohereModel(config, quant_config)
         self.sampler = Sampler()

     @torch.no_grad()
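Note: the final hunk shows the call-site side of the migration. The stored attribute is renamed from self.linear_method to self.quant_config, and the config is passed positionally down to CohereModel. A compact end-to-end toy of the same plumbing, from the top-level model down through the layer stack (illustrative stand-ins, not the vLLM classes):

    from typing import Optional

    import torch
    from torch import nn


    class ToyDecoderLayer(nn.Module):
        def __init__(self, hidden: int,
                     quant_config: Optional[object] = None):
            super().__init__()
            self.quant_config = quant_config   # would drive kernel choice
            self.proj = nn.Linear(hidden, hidden, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    class ToyModel(nn.Module):
        """Mirrors CohereModel: build N layers, handing each the config."""

        def __init__(self, hidden: int, n_layers: int,
                     quant_config: Optional[object] = None):
            super().__init__()
            self.layers = nn.ModuleList([
                ToyDecoderLayer(hidden, quant_config=quant_config)
                for _ in range(n_layers)
            ])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                x = layer(x)
            return x


    class ToyForCausalLM(nn.Module):
        """Mirrors CohereForCausalLM: store the config, thread it down."""

        def __init__(self, hidden: int = 64, n_layers: int = 2,
                     quant_config: Optional[object] = None):
            super().__init__()
            self.quant_config = quant_config   # was: self.linear_method
            self.model = ToyModel(hidden, n_layers, quant_config)


    lm = ToyForCausalLM()
    print(lm.model(torch.randn(1, 64)).shape)  # torch.Size([1, 64])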