[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
This commit is contained in:
@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
|
||||
DeepseekMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
for idx in range(self.n_routed_experts)
|
||||
])
|
||||
@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=False,
|
||||
linear_method=None)
|
||||
quant_config=None)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
|
||||
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
|
||||
else:
|
||||
self.mlp = DeepseekMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
DeepseekDecoderLayer(config,
|
||||
layer_idx,
|
||||
linear_method=linear_method)
|
||||
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = DeepseekModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
Reference in New Issue
Block a user