[Bugfix] fix modelopt exclude_modules name mapping (#24178)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp):
             output_size=self.conv_dim,
             bias=use_conv_bias,
             quant_config=None,
+            prefix=f"{prefix}.conv1d",
         )
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
@@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp):
             output_size=intermediate_size + self.conv_dim + self.num_heads,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.in_proj",
         )
 
         # - because in_proj is a concatenation of 3 weights, we
@@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp):
             bias=use_bias,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
 
         self.norm = Mixer2RMSNormGated(intermediate_size,
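The change threads an explicit `prefix` into each linear layer of `MambaMixer2`, so every layer carries its fully qualified module name from construction. modelopt's `exclude_modules` entries are matched against such names, so a layer built without a prefix could never be matched and excluded from quantization. Below is a minimal sketch of that kind of prefix-based matching; the `is_excluded` helper and the pattern list are illustrative assumptions, not vLLM's actual modelopt implementation:

```python
import fnmatch

# Hypothetical exclude patterns, as they might appear in a modelopt
# quantization config's "exclude_modules" list. Illustrative only.
EXCLUDE_MODULES = ["*.conv1d", "model.layers.*.mixer.in_proj"]

def is_excluded(name: str) -> bool:
    """Return True if a fully qualified module name matches any pattern."""
    return any(fnmatch.fnmatch(name, pattern) for pattern in EXCLUDE_MODULES)

# With the fix, conv1d is constructed with prefix=f"{prefix}.conv1d",
# so its fully qualified name can match an exclusion pattern:
assert is_excluded("model.layers.0.mixer.conv1d")

# Without a prefix, the layer's name is effectively empty, never matches,
# and the exclusion silently fails -- the bug this commit addresses:
assert not is_excluded("")
```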