[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Author: Wang Kunpeng
Date: 2026-01-06 04:03:18 +08:00
Committed by: GitHub
Parent commit: 02dbb933cb
Commit: 5708297e4e

17 changed files with 181 additions and 40 deletions
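The `prefix` argument is the layer's dotted path inside the model (e.g. `model.layers.0.mixer.in_proj`). vLLM's linear layers hand it to `quant_config.get_quant_method(...)`, so per-layer quantization overrides and weight-loading logic can be matched by name; without it, the Mamba projections could not be resolved by such lookups. A minimal sketch of the threading pattern this commit applies (hypothetical `Linear`/`Mixer` classes, not the real vLLM signatures):

# Sketch of prefix threading: each submodule extends the parent's
# prefix with its own attribute name (hypothetical classes).
class Linear:
    def __init__(self, in_size: int, out_size: int, prefix: str = ""):
        # In vLLM, this prefix is forwarded to the quant config so
        # per-layer overrides can match the layer by its dotted path.
        self.prefix = prefix

class Mixer:
    def __init__(self, hidden: int, inter: int, prefix: str = ""):
        self.in_proj = Linear(hidden, inter, prefix=f"{prefix}.in_proj")
        self.out_proj = Linear(inter, hidden, prefix=f"{prefix}.out_proj")

mixer = Mixer(1024, 2048, prefix="model.layers.0.mixer")
print(mixer.in_proj.prefix)  # -> model.layers.0.mixer.in_proj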

@@ -82,6 +82,7 @@ class MambaMixer(MambaBase, CustomOp):
             input_size=conv_kernel_size,
             output_size=intermediate_size,
             bias=use_conv_bias,
+            prefix=f"{prefix}.conv1d",
         )
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
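The unsqueeze mentioned in that comment exists because the linear-style loader stores the depthwise-conv weight as 2D `(channels, kernel_size)`, while `F.conv1d` with `groups=channels` expects 3D `(channels, 1, kernel_size)`. A small self-contained illustration (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

# The linear-style loader holds the depthwise-conv weight as 2D
# (channels, kernel_size); conv1d with groups=channels wants 3D.
w2d = torch.randn(4, 3)                    # (channels, kernel_size), as loaded
w3d = w2d.unsqueeze(1)                     # -> (channels, 1, kernel_size)
x = torch.randn(1, 4, 10)                  # (batch, channels, seq_len)
y = F.conv1d(x, w3d, groups=4, padding=1)
print(y.shape)                             # torch.Size([1, 4, 10])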
@@ -90,7 +91,10 @@ class MambaMixer(MambaBase, CustomOp):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
         self.in_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=use_bias
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=use_bias,
+            prefix=f"{prefix}.in_proj",
         )
 
         # selective projection used to make dt, B and C input dependent
@@ -98,12 +102,17 @@ class MambaMixer(MambaBase, CustomOp):
             intermediate_size,
             time_step_rank + ssm_state_size * 2,
             bias=False,
+            prefix=f"{prefix}.x_proj",
         )
         # time step projection (discretization) -
         # In the forward we need to apply dt_proj without the bias,
         # as the bias is added in the selective scan kernel.
         self.dt_proj = ColumnParallelLinear(
-            time_step_rank, intermediate_size, bias=True, skip_bias_add=True
+            time_step_rank,
+            intermediate_size,
+            bias=True,
+            skip_bias_add=True,
+            prefix=f"{prefix}.dt_proj",
         )
 
         def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
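As the comment above notes, `skip_bias_add=True` makes the layer return its output with the bias unapplied, alongside the bias tensor itself, so the caller can hand the bias to the selective-scan kernel. A simplified sketch of that contract (plain torch, not the tensor-parallel vLLM class):

import torch

# Simplified model of skip_bias_add: return (unbiased output, bias)
# and let the caller decide where the bias gets added.
class SkipBiasAddLinear(torch.nn.Linear):
    def forward(self, x: torch.Tensor):
        return torch.nn.functional.linear(x, self.weight), self.bias

dt_proj = SkipBiasAddLinear(8, 16, bias=True)
dt, dt_bias = dt_proj(torch.randn(2, 8))
# dt_bias would be passed to the selective-scan kernel, which adds it itself.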
@@ -136,6 +145,7 @@ class MambaMixer(MambaBase, CustomOp):
             hidden_size,
             bias=use_bias,
             input_is_parallel=True,
+            prefix=f"{prefix}.out_proj",
         )
 
         self.dt_layernorm = (