[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -82,6 +82,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
input_size=conv_kernel_size,
|
||||
output_size=intermediate_size,
|
||||
bias=use_conv_bias,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
@@ -90,7 +91,10 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=use_bias
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=use_bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
@@ -98,12 +102,17 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
intermediate_size,
|
||||
time_step_rank + ssm_state_size * 2,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.x_proj",
|
||||
)
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(
|
||||
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
|
||||
time_step_rank,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
skip_bias_add=True,
|
||||
prefix=f"{prefix}.dt_proj",
|
||||
)
|
||||
|
||||
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
@@ -136,6 +145,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.dt_layernorm = (
|
||||
|
||||
Reference in New Issue
Block a user