[BugFix] Fix TP support for AWQ (#1731)

This commit is contained in:
Woosuk Kwon
2023-11-20 21:42:45 -08:00
committed by GitHub
parent 4bb6b67188
commit cf35d8f3d7
2 changed files with 38 additions and 14 deletions

View File

@@ -129,9 +129,6 @@ class OPTDecoderLayer(nn.Module):
linear_method=linear_method,
)
self.do_layer_norm_before = config.do_layer_norm_before
quant_config = getattr(linear_method, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
@@ -142,6 +139,9 @@ class OPTDecoderLayer(nn.Module):
bias=config.enable_bias,
linear_method=linear_method,
)
quant_config = getattr(linear_method, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,