Add AWQ support for all models (#1714)

This commit is contained in:
Woosuk Kwon
2023-11-18 17:56:47 -08:00
committed by GitHub
parent e946260cf3
commit 8d17774f92
13 changed files with 90 additions and 17 deletions

View File

@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
linear_method=linear_method,
)
self.do_layer_norm_before = config.do_layer_norm_before
-self.activation_fn = get_act_fn(config.activation_function)
+quant_config = getattr(linear_method, "quant_config", None)
+self.activation_fn = get_act_fn(config.activation_function,
+quant_config, config.ffn_dim)
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
@@ -251,7 +253,7 @@ class OPTDecoder(nn.Module):
inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
-inputs_embeds = self.project_in(inputs_embeds)
+inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
for i in range(len(self.layers)):
@@ -266,7 +268,7 @@ class OPTDecoder(nn.Module):
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
-hidden_states = self.project_out(hidden_states)
+hidden_states, _ = self.project_out(hidden_states)
return hidden_states