[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -127,11 +127,16 @@ class AriaProjectorMLP(nn.Module):
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
output_dim: int,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False)
|
||||
self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False)
|
||||
self.linear_in = ColumnParallelLinear(
|
||||
in_features, hidden_features, bias=False, prefix=f"{prefix}.linear_in"
|
||||
)
|
||||
self.linear_out = RowParallelLinear(
|
||||
hidden_features, output_dim, bias=False, prefix=f"{prefix}.linear_out"
|
||||
)
|
||||
self.act = get_act_fn("gelu_new")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -154,7 +159,7 @@ class AriaProjector(nn.Module):
|
||||
A tensor with the shape of (batch_size, query_number, output_dim)
|
||||
"""
|
||||
|
||||
def __init__(self, config: AriaConfig) -> None:
|
||||
def __init__(self, config: AriaConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_to_query_dict = config.projector_patch_to_query_dict
|
||||
@@ -174,7 +179,10 @@ class AriaProjector(nn.Module):
|
||||
|
||||
self.layer_norm = nn.LayerNorm(self.in_features)
|
||||
self.feed_forward = AriaProjectorMLP(
|
||||
self.in_features, self.hidden_features, self.output_dim
|
||||
self.in_features,
|
||||
self.hidden_features,
|
||||
self.output_dim,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -536,7 +544,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.vision_tower",
|
||||
)
|
||||
self.multi_modal_projector = AriaProjector(config)
|
||||
self.multi_modal_projector = AriaProjector(
|
||||
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
|
||||
)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.language_model = AriaTextModel(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
|
||||
Reference in New Issue
Block a user