[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2026-01-06 04:03:18 +08:00
committed by GitHub
parent 02dbb933cb
commit 5708297e4e
17 changed files with 181 additions and 40 deletions

View File

@@ -127,11 +127,16 @@ class AriaProjectorMLP(nn.Module):
in_features: int,
hidden_features: int,
output_dim: int,
prefix: str = "",
) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False)
self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False)
self.linear_in = ColumnParallelLinear(
in_features, hidden_features, bias=False, prefix=f"{prefix}.linear_in"
)
self.linear_out = RowParallelLinear(
hidden_features, output_dim, bias=False, prefix=f"{prefix}.linear_out"
)
self.act = get_act_fn("gelu_new")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -154,7 +159,7 @@ class AriaProjector(nn.Module):
A tensor with the shape of (batch_size, query_number, output_dim)
"""
def __init__(self, config: AriaConfig) -> None:
def __init__(self, config: AriaConfig, prefix: str = "") -> None:
super().__init__()
self.patch_to_query_dict = config.projector_patch_to_query_dict
@@ -174,7 +179,10 @@ class AriaProjector(nn.Module):
self.layer_norm = nn.LayerNorm(self.in_features)
self.feed_forward = AriaProjectorMLP(
self.in_features, self.hidden_features, self.output_dim
self.in_features,
self.hidden_features,
self.output_dim,
prefix=f"{prefix}.feed_forward",
)
def forward(
@@ -536,7 +544,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config=quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(config)
self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),